Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     17 
     18 #include "llvm/Support/raw_ostream.h"
     19 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
     20 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     21 
     22 namespace xla {
     23 namespace cpu {
     24 VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
     25                                            int64 vector_size,
     26                                            llvm::IRBuilder<>* ir_builder,
     27                                            std::string name)
     28     : vector_size_(vector_size),
     29       primitive_type_(primitive_type),
     30       ir_builder_(ir_builder),
     31       name_(std::move(name)) {
     32   scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
     33       primitive_type, ir_builder_->GetInsertBlock()->getModule());
     34   scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
     35   vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
     36   vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
     37 }
     38 
     39 static string TypeToString(llvm::Type* type) {
     40   std::string o;
     41   llvm::raw_string_ostream ostream(o);
     42   type->print(ostream);
     43   return ostream.str();
     44 }
     45 
     46 void VectorSupportLibrary::AssertCorrectTypes(
     47     std::initializer_list<llvm::Value*> values) {
     48   for (llvm::Value* v : values) {
     49     llvm::Type* type = v->getType();
     50     if (type != scalar_type() && type != vector_type()) {
     51       LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
     52                  << TypeToString(vector_type()) << " but got "
     53                  << TypeToString(type);
     54     }
     55   }
     56 }
     57 
     58 llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
     59   AssertCorrectTypes({lhs, rhs});
     60   return MulInternal(lhs, rhs);
     61 }
     62 
     63 llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
     64                                                llvm::Value* rhs) {
     65   if (scalar_type_->isFloatingPointTy()) {
     66     return ir_builder()->CreateFMul(lhs, rhs, name());
     67   } else {
     68     return ir_builder()->CreateMul(lhs, rhs, name());
     69   }
     70 }
     71 
     72 llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
     73   AssertCorrectTypes({lhs, rhs});
     74   return AddInternal(lhs, rhs);
     75 }
     76 
     77 llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
     78   AssertCorrectTypes({lhs, rhs});
     79   return ir_builder()->CreateFSub(lhs, rhs);
     80 }
     81 
     82 llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
     83   AssertCorrectTypes({lhs, rhs});
     84   if (scalar_type_->isFloatingPointTy()) {
     85     return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_);
     86   } else {
     87     LOG(FATAL) << "Max for integers is unimplemented";
     88   }
     89 }
     90 
     91 llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
     92   AssertCorrectTypes({a});
     93   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
     94                                       {a->getType()}, ir_builder());
     95 }
     96 
     97 llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
     98   AssertCorrectTypes({lhs, rhs});
     99   if (scalar_type_->isFloatingPointTy()) {
    100     return ir_builder()->CreateFDiv(lhs, rhs, name());
    101   } else {
    102     LOG(FATAL) << "Division for integers is unimplemented";
    103   }
    104 }
    105 
    106 llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
    107                                          const llvm::APFloat& low,
    108                                          const llvm::APFloat& high) {
    109   AssertCorrectTypes({a});
    110   llvm::Type* type = a->getType();
    111   CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
    112   CHECK(scalar_type_->isFloatingPointTy());
    113   return llvm_ir::EmitFloatMin(
    114       llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
    115       GetConstantFloat(type, high), ir_builder_);
    116 }
    117 
    118 llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
    119                                               llvm::Value* rhs) {
    120   AssertCorrectTypes({lhs, rhs});
    121   return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
    122 }
    123 
    124 llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
    125                                                llvm::Value* rhs) {
    126   AssertCorrectTypes({lhs, rhs});
    127   return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
    128 }
    129 
    130 llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
    131                                                llvm::Value* rhs) {
    132   AssertCorrectTypes({lhs, rhs});
    133   return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
    134 }
    135 
    136 llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
    137   bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
    138   llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
    139   return ir_builder()->CreateBitCast(
    140       ir_builder()->CreateSExt(i1, integer_type, name()),
    141       is_vector ? vector_type() : scalar_type(), name());
    142 }
    143 
    144 llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
    145   CHECK(scalar_type()->isFloatingPointTy());
    146   const llvm::DataLayout& data_layout =
    147       ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
    148   int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
    149   llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
    150   if (vector) {
    151     return llvm::VectorType::get(scalar_int_type, vector_size());
    152   } else {
    153     return scalar_int_type;
    154   }
    155 }
    156 
    157 llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
    158   CHECK_EQ(x->getType(), scalar_type());
    159   return ir_builder()->CreateVectorSplat(vector_size(), x, name());
    160 }
    161 
    162 llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
    163                                             llvm::Value* rhs) {
    164   AssertCorrectTypes({lhs, rhs});
    165   llvm::Type* int_type =
    166       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    167   return ir_builder()->CreateBitCast(
    168       ir_builder()->CreateAnd(
    169           ir_builder()->CreateBitCast(lhs, int_type, name()),
    170           ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
    171       vector_type());
    172 }
    173 
    174 llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
    175   AssertCorrectTypes({lhs});
    176   llvm::Type* int_type =
    177       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    178   return ir_builder()->CreateBitCast(
    179       ir_builder()->CreateNot(
    180           ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
    181       vector_type());
    182 }
    183 
    184 llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
    185   AssertCorrectTypes({lhs, rhs});
    186   llvm::Type* int_type =
    187       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    188   return ir_builder()->CreateBitCast(
    189       ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
    190                              ir_builder()->CreateBitCast(rhs, int_type, name()),
    191                              name()),
    192       vector_type(), name());
    193 }
    194 
    195 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
    196                                                llvm::Value* rhs) {
    197   if (scalar_type_->isFloatingPointTy()) {
    198     return ir_builder()->CreateFAdd(lhs, rhs, name());
    199   } else {
    200     return ir_builder()->CreateAdd(lhs, rhs, name());
    201   }
    202 }
    203 
    204 llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    205     llvm::Value* base_pointer, llvm::Value* offset_elements) {
    206   if (base_pointer->getType() != scalar_pointer_type()) {
    207     base_pointer = ir_builder()->CreateBitCast(base_pointer,
    208                                                scalar_pointer_type(), name());
    209   }
    210   return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
    211                                          name());
    212 }
    213 
    214 llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
    215   if (pointer->getType() != vector_pointer_type()) {
    216     pointer =
    217         ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
    218   }
    219   return ir_builder()->CreateAlignedLoad(
    220       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
    221 }
    222 
    223 llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
    224   if (pointer->getType() != scalar_pointer_type()) {
    225     pointer =
    226         ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
    227   }
    228   return ir_builder()->CreateAlignedLoad(
    229       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
    230 }
    231 
    232 void VectorSupportLibrary::StoreVector(llvm::Value* value,
    233                                        llvm::Value* pointer) {
    234   AssertCorrectTypes({value});
    235   if (pointer->getType() != vector_pointer_type()) {
    236     pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
    237   }
    238   ir_builder()->CreateAlignedStore(
    239       value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
    240 }
    241 
    242 void VectorSupportLibrary::StoreScalar(llvm::Value* value,
    243                                        llvm::Value* pointer) {
    244   AssertCorrectTypes({value});
    245   if (pointer->getType() != scalar_pointer_type()) {
    246     pointer =
    247         ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
    248   }
    249   ir_builder()->CreateAlignedStore(
    250       value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
    251 }
    252 
    253 llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
    254   if (pointer->getType() != scalar_pointer_type()) {
    255     pointer =
    256         ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
    257   }
    258   return ir_builder()->CreateVectorSplat(
    259       vector_size(), ir_builder()->CreateLoad(pointer), name());
    260 }
    261 
    262 llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
    263   llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
    264   for (unsigned i = vector_size(); i != 1; i >>= 1) {
    265     // On every iteration, we shuffle half of the remaining lanes to the top
    266     // half of shuffle, and add two old and the new vector.
    267 
    268     for (unsigned j = 0; j < vector_size(); ++j) {
    269       if (j < (i / 2)) {
    270         mask[j] = ir_builder()->getInt32(i / 2 + j);
    271       } else {
    272         mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
    273       }
    274     }
    275 
    276     llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
    277         vector, llvm::UndefValue::get(vector_type()),
    278         llvm::ConstantVector::get(mask), "");
    279     vector = Add(vector, half_remaining_lanes);
    280   }
    281 
    282   return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
    283                                             name());
    284 }
    285 
    286 llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
    287                                                          llvm::Value* rhs) {
    288   CHECK_EQ(lhs->getType(), vector_type());
    289   CHECK_EQ(rhs->getType(), vector_type());
    290   CHECK_EQ(vector_size() % 2, 0);
    291 
    292   llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;
    293 
    294   // Adding the values shuffled using mask_a and mask_b gives us the
    295   // AVX-style horizontal add we want.  The masks work as documented
    296   // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
    297   //
    298   // Here are the masks for vector_width() == 8:
    299   //
    300   //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
    301   //   --------+--+--+--+---+--+--+---+---
    302   //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
    303   //   mask_b: |1 |3 |9 |11 |5 |7 |13 |16
    304   //
    305   // So, as an example, the value at lane 3 of the result vector is
    306   // the result of adding lane 10 and lane 11 in the combined lhs++rhs
    307   // vector, which are the lanes 2 and 3 in the rhs vector.
    308   for (int i = 0; i < vector_size(); i += 2) {
    309     int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    310     mask_a.push_back(ir_builder()->getInt32(increment + i));
    311     mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
    312   }
    313   for (int i = 0; i < vector_size(); i += 2) {
    314     int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    315     mask_a.push_back(ir_builder()->getInt32(increment + i));
    316     mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
    317   }
    318 
    319   llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector(
    320       lhs, rhs, llvm::ConstantVector::get(mask_a));
    321   llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector(
    322       lhs, rhs, llvm::ConstantVector::get(mask_b));
    323 
    324   return Add(shuffle_0, shuffle_1);
    325 }
    326 
    327 llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
    328   llvm::SmallVector<llvm::Constant*, 32> mask;
    329   for (int i = 0; i < vector_size() / 2; i++) {
    330     mask.push_back(ir_builder()->getInt32(i));
    331   }
    332 
    333   return ir_builder()->CreateShuffleVector(vector,
    334                                            llvm::UndefValue::get(vector_type()),
    335                                            llvm::ConstantVector::get(mask));
    336 }
    337 
    338 llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
    339   llvm::SmallVector<llvm::Constant*, 32> mask;
    340   for (int i = 0; i < vector_size() / 2; i++) {
    341     mask.push_back(ir_builder()->getInt32(i + vector_size() / 2));
    342   }
    343 
    344   return ir_builder()->CreateShuffleVector(vector,
    345                                            llvm::UndefValue::get(vector_type()),
    346                                            llvm::ConstantVector::get(mask));
    347 }
    348 
    349 std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    350     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
    351   const int x86_avx_vector_elements =
    352       TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
    353   if (vector_size() == x86_avx_vector_elements &&
    354       vectors.size() == x86_avx_vector_elements) {
    355     return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
    356   }
    357 
    358   std::vector<llvm::Value*> result;
    359   std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
    360                  [this](llvm::Value* vector) { return AddReduce(vector); });
    361   if (init_values) {
    362     for (int64 i = 0, e = result.size(); i < e; i++) {
    363       result[i] = Add(result[i], ir_builder()->CreateExtractElement(
    364                                      init_values, ir_builder()->getInt32(i)));
    365     }
    366   }
    367   return result;
    368 }
    369 
    370 std::vector<llvm::Value*>
    371 VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    372     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
    373   while (vectors.size() != 2) {
    374     std::vector<llvm::Value*> new_vectors;
    375     for (int i = 0; i < vectors.size(); i += 2) {
    376       new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    377     }
    378 
    379     vectors = std::move(new_vectors);
    380   }
    381 
    382   llvm::Value* low =
    383       AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
    384   if (init_values) {
    385     low = AddInternal(ExtractLowHalf(init_values), low);
    386   }
    387   llvm::Value* high =
    388       AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
    389   if (init_values) {
    390     high = AddInternal(ExtractHighHalf(init_values), high);
    391   }
    392 
    393   std::vector<llvm::Value*> results;
    394   for (int i = 0; i < 8; i++) {
    395     llvm::Value* scalar_result = ir_builder()->CreateExtractElement(
    396         i < 4 ? low : high, ir_builder()->getInt32(i % 4), name());
    397     results.push_back(scalar_result);
    398   }
    399 
    400   return results;
    401 }
    402 
    403 llvm::Value* VectorSupportLibrary::GetZeroVector() {
    404   return llvm::Constant::getNullValue(vector_type());
    405 }
    406 
    407 llvm::Value* VectorSupportLibrary::GetZeroScalar() {
    408   return llvm::Constant::getNullValue(scalar_type());
    409 }
    410 
    411 LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
    412     : ir_builder_(ir_builder) {
    413   alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
    414 }
    415 
    416 llvm::Value* LlvmVariable::Get() const {
    417   return ir_builder_->CreateLoad(alloca_);
    418 }
    419 
    420 void LlvmVariable::Set(llvm::Value* new_value) {
    421   ir_builder_->CreateStore(new_value, alloca_);
    422 }
    423 }  // namespace cpu
    424 }  // namespace xla
    425