Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
     17 
     18 #include "llvm/IR/Function.h"
     19 #include "llvm/IR/IRBuilder.h"
     20 #include "llvm/IR/Intrinsics.h"
     21 #include "llvm/IR/Verifier.h"
     22 #include "llvm/Transforms/Utils/Cloning.h"
     23 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     24 #include "tensorflow/core/lib/core/casts.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 
     27 namespace xla {
     28 namespace cpu {
     29 namespace runtime {
     30 
// Symbol names of the vectorized (4- and 8-lane) f32 transcendental routines.
// These are only *declared* by code emission elsewhere; their bodies are
// generated on demand by RewriteIRRuntimeFunctions in this file when a module
// actually references them.
const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
// NOTE: unlike the tanh/exp symbols above, the log symbols carry an "AVX"
// suffix — presumably historical; confirm against the emitters that reference
// these names before renaming.
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
     37 
     38 namespace {
     39 llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
     40                                           llvm::StringRef function_name,
     41                                           int vector_width,
     42                                           bool enable_fast_math) {
     43   llvm::Function* vector_tanh_function = module->getFunction(function_name);
     44   if (vector_tanh_function == nullptr) {
     45     // If the function declaration is not present in the module, there can't be
     46     // any calls to resolve.  Don't emit the function in this case.
     47     return nullptr;
     48   }
     49 
     50   llvm::LLVMContext* context = &module->getContext();
     51 
     52   llvm::BasicBlock* vector_tanh_body =
     53       llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
     54 
     55   llvm::IRBuilder<> ir_builder(vector_tanh_body);
     56   llvm::FastMathFlags fast_math_flags;
     57   fast_math_flags.setFast();
     58   ir_builder.setFastMathFlags(fast_math_flags);
     59 
     60   VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32");
     61 
     62   llvm::Value* input = &*vector_tanh_function->arg_begin();
     63   CHECK_EQ(input->getType(), vsl.vector_type());
     64 
     65   // This implements the same rational interpolant as implemented in Eigen3.
     66   llvm::Value* input_clamped =
     67       vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0));
     68 
     69   std::array<float, 7> numerator_coeffs{
     70       -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
     71       5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
     72       4.89352455891786e-03f};
     73 
     74   std::array<float, 4> denominator_coeffs{
     75       1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
     76       4.89352518554385e-03f};
     77 
     78   llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped);
     79   llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0]));
     80   for (int i = 1; i < numerator_coeffs.size(); i++) {
     81     numerator =
     82         vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i]));
     83   }
     84 
     85   numerator = vsl.Mul(input_clamped, numerator);
     86 
     87   llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0]));
     88   for (int i = 1; i < denominator_coeffs.size(); i++) {
     89     denominator = vsl.MulAdd(input_squared, denominator,
     90                              GetIeeeF32(denominator_coeffs[i]));
     91   }
     92 
     93   llvm::Value* result = vsl.Div(numerator, denominator);
     94   ir_builder.CreateRet(result);
     95 
     96   DCHECK(!llvm::verifyFunction(*vector_tanh_function));
     97   return vector_tanh_function;
     98 }
     99 
    100 llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
    101                                          llvm::StringRef function_name,
    102                                          int vector_width,
    103                                          bool enable_fast_math) {
    104   llvm::Function* vector_exp_function = module->getFunction(function_name);
    105   if (vector_exp_function == nullptr) {
    106     // If the function declaration is not present in the module, there can't be
    107     // any calls to resolve.  Don't emit the function in this case.
    108     return nullptr;
    109   }
    110 
    111   llvm::LLVMContext* context = &module->getContext();
    112 
    113   llvm::BasicBlock* vector_exp_body =
    114       llvm::BasicBlock::Create(*context, "body", vector_exp_function);
    115 
    116   llvm::IRBuilder<> ir_builder(vector_exp_body);
    117   llvm::FastMathFlags fast_math_flags;
    118   fast_math_flags.setFast();
    119   ir_builder.setFastMathFlags(fast_math_flags);
    120 
    121   VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32");
    122 
    123   // This implements the same polynomial approximation as implemented in Eigen3.
    124 
    125   const llvm::APFloat half = GetIeeeF32(0.5);
    126   const llvm::APFloat one = GetIeeeF32(1.0);
    127 
    128   const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
    129   const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
    130 
    131   const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
    132   const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
    133   const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);
    134 
    135   const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
    136   const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
    137   const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
    138   const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
    139   const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
    140   const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
    141 
    142   llvm::Value* input = &*vector_exp_function->arg_begin();
    143   llvm::Value* input_clamped =
    144       vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
    145   llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
    146   llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
    147   llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
    148   llvm::Value* x = vsl.Sub(input_clamped, tmp);
    149   x = vsl.Sub(x, z);
    150   z = vsl.Mul(x, x);
    151 
    152   llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
    153   y = vsl.MulAdd(y, x, cephes_exp_p2);
    154   y = vsl.MulAdd(y, x, cephes_exp_p3);
    155   y = vsl.MulAdd(y, x, cephes_exp_p4);
    156   y = vsl.MulAdd(y, x, cephes_exp_p5);
    157   y = vsl.MulAdd(y, z, x);
    158   y = vsl.Add(one, y);
    159 
    160   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
    161   // time so drop down to IRBuilder for this bit.
    162   llvm::Value* vector_constant_0x7f =
    163       ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
    164   llvm::Value* vector_constant_23 =
    165       ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
    166   llvm::Type* i32_vector_type =
    167       llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
    168   // fx is clamped so we don't have to worry about it being out of range for
    169   // i32.
    170   llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type);
    171   emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f);
    172   emm0 = ir_builder.CreateShl(emm0, vector_constant_23);
    173   llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type());
    174 
    175   llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input);
    176 
    177   ir_builder.CreateRet(result);
    178 
    179   DCHECK(!llvm::verifyFunction(*vector_exp_function));
    180   return vector_exp_function;
    181 }
    182 
    183 llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
    184                                          llvm::StringRef function_name,
    185                                          int vector_width,
    186                                          bool enable_fast_math) {
    187   llvm::Function* vector_log_function = module->getFunction(function_name);
    188   if (vector_log_function == nullptr) {
    189     // If the function declaration is not present in the module, there can't be
    190     // any calls to resolve.  Don't emit the function in this case.
    191     return nullptr;
    192   }
    193 
    194   llvm::LLVMContext* context = &module->getContext();
    195 
    196   llvm::BasicBlock* vector_log_body =
    197       llvm::BasicBlock::Create(*context, "body", vector_log_function);
    198 
    199   llvm::IRBuilder<> ir_builder(vector_log_body);
    200   llvm::FastMathFlags fast_math_flags;
    201   fast_math_flags.setFast();
    202   ir_builder.setFastMathFlags(fast_math_flags);
    203 
    204   llvm::Value* input = &*vector_log_function->arg_begin();
    205   VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
    206 
    207   const llvm::APFloat half = GetIeeeF32(0.5);
    208   const llvm::APFloat one = GetIeeeF32(1.0);
    209 
    210   // This implements the same polynomial approximation as implemented in Eigen3.
    211   // Returns NaN for x < 0, -INF for x = 0
    212   const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
    213   const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
    214   const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
    215   const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
    216   const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
    217   const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
    218   const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
    219   const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
    220   const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
    221   const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
    222   const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
    223   const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);
    224 
    225   // The smallest non denormalized float number.
    226   const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
    227   const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
    228   const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);
    229 
    230   // invalid_mask is set if x is negative or NaN (and therefore output
    231   // must be NaN).
    232   llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
    233   llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
    234 
    235   // Cut off denormalized stuff.
    236   input = vsl.Max(min_norm_pos, input);
    237 
    238   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
    239   // time so drop down to IRBuilder for this bit.
    240   llvm::Value* vector_constant_0x7f =
    241       ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
    242   llvm::Value* vector_constant_23 =
    243       ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
    244   llvm::Type* i32_vector_type =
    245       llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
    246 
    247   llvm::Value* emm0 = ir_builder.CreateLShr(
    248       ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23);
    249 
    250   // Keep only the fractional part.
    251   input = vsl.FloatAnd(input, inv_mant_mask);
    252   input = vsl.FloatOr(input, half);
    253 
    254   emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
    255   llvm::Value* e =
    256       vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
    257 
    258   // part2:
    259   //   if( x < SQRTHF ) {
    260   //     e -= 1;
    261   //     x = x + x - 1.0;
    262   //   } else { x = x - 1.0; }
    263   llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
    264   llvm::Value* tmp = vsl.FloatAnd(input, mask);
    265   input = vsl.Sub(input, one);
    266   e = vsl.Sub(e, vsl.FloatAnd(mask, one));
    267   input = vsl.Add(input, tmp);
    268 
    269   llvm::Value* x2 = vsl.Mul(input, input);
    270   llvm::Value* x3 = vsl.Mul(x2, input);
    271 
    272   llvm::Value *y, *y1, *y2;
    273   y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1);
    274   y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4);
    275   y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7);
    276   y = vsl.MulAdd(y, input, cephes_log_p2);
    277   y1 = vsl.MulAdd(y1, input, cephes_log_p5);
    278   y2 = vsl.MulAdd(y2, input, cephes_log_p8);
    279   y = vsl.MulAdd(y, x3, y1);
    280   y = vsl.MulAdd(y, x3, y2);
    281   y = vsl.Mul(y, x3);
    282 
    283   y1 = vsl.Mul(cephes_log_q1, e);
    284   tmp = vsl.Mul(half, x2);
    285   y = vsl.Add(y, y1);
    286   input = vsl.Sub(input, tmp);
    287   y2 = vsl.Mul(cephes_log_q2, e);
    288   input = vsl.Add(input, y);
    289   input = vsl.Add(input, y2);
    290 
    291   // Negative arg will be NAN, 0 will be -INF.
    292   llvm::Value* or_lhs =
    293       vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask));
    294   llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
    295   llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
    296 
    297   ir_builder.CreateRet(result);
    298 
    299   DCHECK(!llvm::verifyFunction(*vector_log_function));
    300   return vector_log_function;
    301 }
    302 }  // namespace
    303 
    304 void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
    305   auto* tanh_v4f32 =
    306       EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
    307                                 /*vector_width=*/4, enable_fast_math);
    308   auto* tanh_v8f32 =
    309       EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
    310                                 /*vector_width=*/8, enable_fast_math);
    311 
    312   auto* exp_v4f32 =
    313       EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName,
    314                                /*vector_width=*/4, enable_fast_math);
    315   auto* exp_v8f32 =
    316       EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName,
    317                                /*vector_width=*/8, enable_fast_math);
    318 
    319   auto* log_v4f32 =
    320       EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName,
    321                                /*vector_width=*/4, enable_fast_math);
    322   auto* log_v8f32 =
    323       EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName,
    324                                /*vector_width=*/8, enable_fast_math);
    325 
    326   // Gather all the call sites, force inline them and then delete the vector
    327   // function bodies.
    328   //
    329   // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases?
    330 
    331   std::vector<llvm::CallInst*> calls_to_inline;
    332   for (auto* function :
    333        {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
    334     if (function != nullptr) {
    335       for (auto* user : function->users()) {
    336         calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
    337       }
    338     }
    339   }
    340 
    341   for (auto* call_to_inline : calls_to_inline) {
    342     llvm::InlineFunctionInfo inline_function_info;
    343     CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
    344   }
    345 
    346   for (auto* function :
    347        {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
    348     if (function != nullptr) {
    349       function->eraseFromParent();
    350     }
    351   }
    352 }
    353 
    354 }  // namespace runtime
    355 }  // namespace cpu
    356 }  // namespace xla
    357