Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
     17 
     18 #include "llvm/IR/Function.h"
     19 #include "llvm/IR/IRBuilder.h"
     20 #include "llvm/IR/Intrinsics.h"
     21 #include "llvm/IR/Verifier.h"
     22 #include "llvm/Transforms/Utils/Cloning.h"
     23 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     24 #include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 
     27 namespace xla {
     28 namespace cpu {
     29 namespace runtime {
     30 
     31 const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
     32 const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
     33 const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
     34 const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
     35 const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
     36 const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
     37 
     38 namespace {
     39 
     40 // Replaces calls to the function `fn_name` with the code generated by
     41 // fn_body_generator.
     42 //
     43 // We assume that fn_name accepts either a scalar f32 or a vector of
     44 // vector_width f32s, and that fn_body_generator generates a function body with
     45 // the same inputs/outputs as fn_name.
     46 void RewriteCalls(
     47     llvm::Module* module, const char* fn_name,
     48     std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input,
     49                                int32 vector_width)>
     50         fn_body_generator,
     51     int32 vector_width, bool enable_fast_math) {
     52   llvm::Function* fn = module->getFunction(fn_name);
     53   if (fn == nullptr) {
     54     // If the function declaration is not present in the module, there can't be
     55     // any calls to resolve.  Don't emit the function in this case.
     56     return;
     57   }
     58 
     59   // Our task is to generate a function body for `fn`, but we can't generate a
     60   // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it
     61   // with a new function.
     62   if (fn->isIntrinsic()) {
     63     llvm::Function* new_fn = llvm::Function::Create(
     64         fn->getFunctionType(), llvm::GlobalValue::InternalLinkage,
     65         llvm::Twine("xla_impl.") + fn_name, module);
     66     fn->replaceAllUsesWith(new_fn);
     67     fn->eraseFromParent();
     68     fn = new_fn;
     69   }
     70 
     71   llvm::LLVMContext* context = &module->getContext();
     72 
     73   llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn);
     74   llvm::IRBuilder<> b(fn_body);
     75   llvm::FastMathFlags fast_math_flags;
     76   fast_math_flags.setFast(enable_fast_math);
     77   b.setFastMathFlags(fast_math_flags);
     78 
     79   llvm::Value* input = &*fn->arg_begin();
     80 
     81   // Upcast to vector type if input is a scalar.
     82   if (vector_width == 1) {
     83     llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1);
     84     input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input,
     85                                   uint64_t{0});
     86   }
     87 
     88   // Generate the vectorized code.
     89   CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
     90   llvm::Value* result = fn_body_generator(&b, input, vector_width);
     91 
     92   // Downcast result to scalar type if necessary.
     93   if (vector_width == 1) {
     94     result = b.CreateExtractElement(result, uint64_t{0});
     95   }
     96   b.CreateRet(result);
     97   DCHECK(!llvm::verifyFunction(*fn));
     98 
     99   // Force-inline `fn` into all of its callers and then delete `fn`.
    100   //
    101   // TODO(b/73081976): Should we avoid inlining these in some cases?
    102   std::vector<llvm::CallInst*> calls_to_inline;
    103   for (auto* user : fn->users()) {
    104     calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
    105   }
    106   for (auto* call_to_inline : calls_to_inline) {
    107     llvm::InlineFunctionInfo inline_function_info;
    108     CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
    109   }
    110   fn->eraseFromParent();
    111 }
    112 
    113 llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input,
    114                               int32 /*vector_width*/) {
    115   return llvm_ir::EmitFastTanh(b, input);
    116 }
    117 
    118 llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
    119                              int32 vector_width) {
    120   VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");
    121 
    122   // This implements the same polynomial approximation as implemented in Eigen3.
    123 
    124   const llvm::APFloat half = GetIeeeF32(0.5);
    125   const llvm::APFloat one = GetIeeeF32(1.0);
    126 
    127   const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
    128   const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
    129 
    130   const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
    131   const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
    132   const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);
    133 
    134   const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
    135   const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
    136   const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
    137   const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
    138   const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
    139   const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
    140 
    141   llvm::Value* input_clamped =
    142       vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
    143   llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
    144   llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
    145   llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
    146   llvm::Value* x = vsl.Sub(input_clamped, tmp);
    147   x = vsl.Sub(x, z);
    148   z = vsl.Mul(x, x);
    149 
    150   llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
    151   y = vsl.MulAdd(y, x, cephes_exp_p2);
    152   y = vsl.MulAdd(y, x, cephes_exp_p3);
    153   y = vsl.MulAdd(y, x, cephes_exp_p4);
    154   y = vsl.MulAdd(y, x, cephes_exp_p5);
    155   y = vsl.MulAdd(y, z, x);
    156   y = vsl.Add(one, y);
    157 
    158   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
    159   // time so drop down to IRBuilder for this bit.
    160   llvm::Value* vector_constant_0x7f =
    161       b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
    162   llvm::Value* vector_constant_23 =
    163       b->CreateVectorSplat(vector_width, b->getInt32(23));
    164   llvm::Type* i32_vector_type =
    165       llvm::VectorType::get(b->getInt32Ty(), vector_width);
    166   // fx is clamped so we don't have to worry about it being out of range for
    167   // i32.
    168   llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type);
    169   emm0 = b->CreateAdd(emm0, vector_constant_0x7f);
    170   emm0 = b->CreateShl(emm0, vector_constant_23);
    171   llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type());
    172 
    173   return vsl.Max(vsl.Mul(y, emm0_f32), input);
    174 }
    175 
    176 llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
    177                              int32 vector_width) {
    178   VectorSupportLibrary vsl(F32, vector_width, b, "log_f32");
    179 
    180   const llvm::APFloat half = GetIeeeF32(0.5);
    181   const llvm::APFloat one = GetIeeeF32(1.0);
    182 
    183   // This implements the same polynomial approximation as implemented in Eigen3.
    184   // Returns NaN for x < 0, -INF for x = 0
    185   const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
    186   const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
    187   const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
    188   const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
    189   const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
    190   const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
    191   const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
    192   const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
    193   const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
    194   const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
    195   const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
    196   const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);
    197 
    198   // The smallest non denormalized float number.
    199   const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
    200   const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
    201   const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000);
    202   const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);
    203 
    204   // invalid_mask is set if x is negative or NaN (and therefore output
    205   // must be NaN).
    206   llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
    207   llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
    208   llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);
    209 
    210   // Cut off denormalized stuff.
    211   llvm::Value* tmp0 = vsl.Max(min_norm_pos, input);
    212 
    213   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
    214   // time so drop down to IRBuilder for this bit.
    215   llvm::Value* vector_constant_0x7f =
    216       b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
    217   llvm::Value* vector_constant_23 =
    218       b->CreateVectorSplat(vector_width, b->getInt32(23));
    219   llvm::Type* i32_vector_type =
    220       llvm::VectorType::get(b->getInt32Ty(), vector_width);
    221 
    222   llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type),
    223                                     vector_constant_23);
    224 
    225   // Keep only the fractional part.
    226   tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask);
    227   tmp0 = vsl.FloatOr(tmp0, half);
    228 
    229   emm0 = b->CreateSub(emm0, vector_constant_0x7f);
    230   llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type()));
    231 
    232   // part2:
    233   //   if( x < SQRTHF ) {
    234   //     e -= 1;
    235   //     x = x + x - 1.0;
    236   //   } else { x = x - 1.0; }
    237   llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF);
    238   llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask);
    239   tmp0 = vsl.Sub(tmp0, one);
    240   e = vsl.Sub(e, vsl.FloatAnd(mask, one));
    241   tmp0 = vsl.Add(tmp0, tmp1);
    242 
    243   llvm::Value* x2 = vsl.Mul(tmp0, tmp0);
    244   llvm::Value* x3 = vsl.Mul(x2, tmp0);
    245 
    246   llvm::Value *y, *y1, *y2;
    247   y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1);
    248   y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4);
    249   y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7);
    250   y = vsl.MulAdd(y, tmp0, cephes_log_p2);
    251   y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5);
    252   y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8);
    253   y = vsl.MulAdd(y, x3, y1);
    254   y = vsl.MulAdd(y, x3, y2);
    255   y = vsl.Mul(y, x3);
    256 
    257   y1 = vsl.Mul(cephes_log_q1, e);
    258   llvm::Value* tmp2 = vsl.Mul(half, x2);
    259   y = vsl.Add(y, y1);
    260   tmp0 = vsl.Sub(tmp0, tmp2);
    261   y2 = vsl.Mul(cephes_log_q2, e);
    262   tmp0 = vsl.Add(tmp0, y);
    263   tmp0 = vsl.Add(tmp0, y2);
    264 
    265   // Contains +/-inf where +/-inf is the correct answer, otherwise 0.
    266   llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf),
    267                                         vsl.FloatAnd(is_pos_inf_mask, pos_inf));
    268 
    269   // Contains a finite result or nan.  This is the correct answer only if both
    270   // result_minus_inf and result_pos_inf are both 0.
    271   //
    272   // (This implementation works because 0xffffffff is a nan.)
    273   llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask);
    274 
    275   // Combine the above into a final result.
    276   return vsl.FloatOr(result_inf,
    277                      vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask),
    278                                      result_finite_or_nan));
    279 }
    280 }  // namespace
    281 
    282 void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
    283   // Curry some params to RewriteCalls.
    284   auto rewrite_calls =
    285       std::bind(RewriteCalls, module, std::placeholders::_1,
    286                 std::placeholders::_2, std::placeholders::_3, enable_fast_math);
    287 
    288   rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1);
    289   rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1);
    290   rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4);
    291   rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
    292 
    293   rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
    294   rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
    295   rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4);
    296   rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8);
    297 
    298   rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1);
    299   rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1);
    300   rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4);
    301   rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8);
    302 }
    303 
    304 }  // namespace runtime
    305 }  // namespace cpu
    306 }  // namespace xla
    307