// Home | History | Annotate | Download | only in cpu  (code-viewer navigation artifact, not original source)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     17 
     18 #include "absl/algorithm/container.h"
     19 #include "llvm/Support/raw_ostream.h"
     20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
     21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     22 
     23 namespace xla {
     24 namespace cpu {
     25 VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
     26                                            int64 vector_size,
     27                                            llvm::IRBuilder<>* b,
     28                                            std::string name)
     29     : vector_size_(vector_size),
     30       primitive_type_(primitive_type),
     31       b_(b),
     32       name_(std::move(name)) {
     33   scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
     34       primitive_type, b_->GetInsertBlock()->getModule());
     35   scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
     36   vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
     37   vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
     38 }
     39 
     40 static string TypeToString(llvm::Type* type) {
     41   std::string o;
     42   llvm::raw_string_ostream ostream(o);
     43   type->print(ostream);
     44   return ostream.str();
     45 }
     46 
     47 void VectorSupportLibrary::AssertCorrectTypes(
     48     std::initializer_list<llvm::Value*> values) {
     49   for (llvm::Value* v : values) {
     50     llvm::Type* type = v->getType();
     51     if (type != scalar_type() && type != vector_type()) {
     52       LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
     53                  << TypeToString(vector_type()) << " but got "
     54                  << TypeToString(type);
     55     }
     56   }
     57 }
     58 
     59 llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
     60   AssertCorrectTypes({lhs, rhs});
     61   return MulInternal(lhs, rhs);
     62 }
     63 
     64 llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
     65                                                llvm::Value* rhs) {
     66   if (scalar_type_->isFloatingPointTy()) {
     67     return b()->CreateFMul(lhs, rhs, name());
     68   } else {
     69     return b()->CreateMul(lhs, rhs, name());
     70   }
     71 }
     72 
     73 llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
     74   AssertCorrectTypes({lhs, rhs});
     75   return AddInternal(lhs, rhs);
     76 }
     77 
     78 llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
     79   AssertCorrectTypes({lhs, rhs});
     80   return b()->CreateFSub(lhs, rhs);
     81 }
     82 
     83 llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
     84   AssertCorrectTypes({lhs, rhs});
     85   if (scalar_type_->isFloatingPointTy()) {
     86     return llvm_ir::EmitFloatMax(lhs, rhs, b_);
     87   } else {
     88     LOG(FATAL) << "Max for integers is unimplemented";
     89   }
     90 }
     91 
     92 llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
     93   AssertCorrectTypes({a});
     94   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
     95                                       {a->getType()}, b());
     96 }
     97 
     98 llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
     99   AssertCorrectTypes({lhs, rhs});
    100   if (scalar_type_->isFloatingPointTy()) {
    101     return b()->CreateFDiv(lhs, rhs, name());
    102   } else {
    103     LOG(FATAL) << "Division for integers is unimplemented";
    104   }
    105 }
    106 
    107 llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
    108                                          const llvm::APFloat& low,
    109                                          const llvm::APFloat& high) {
    110   AssertCorrectTypes({a});
    111   llvm::Type* type = a->getType();
    112   CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
    113   CHECK(scalar_type_->isFloatingPointTy());
    114   return llvm_ir::EmitFloatMin(
    115       llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
    116       GetConstantFloat(type, high), b_);
    117 }
    118 
    119 llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
    120                                               llvm::Value* rhs) {
    121   AssertCorrectTypes({lhs, rhs});
    122   return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
    123 }
    124 
    125 llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
    126                                                llvm::Value* rhs) {
    127   AssertCorrectTypes({lhs, rhs});
    128   return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
    129 }
    130 
    131 llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
    132                                                llvm::Value* rhs) {
    133   AssertCorrectTypes({lhs, rhs});
    134   return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
    135 }
    136 
    137 llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
    138   bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
    139   llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
    140   return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
    141                             is_vector ? vector_type() : scalar_type(), name());
    142 }
    143 
    144 llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
    145   CHECK(scalar_type()->isFloatingPointTy());
    146   const llvm::DataLayout& data_layout =
    147       b()->GetInsertBlock()->getModule()->getDataLayout();
    148   int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
    149   llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
    150   if (vector) {
    151     return llvm::VectorType::get(scalar_int_type, vector_size());
    152   } else {
    153     return scalar_int_type;
    154   }
    155 }
    156 
    157 llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
    158   CHECK_EQ(x->getType(), scalar_type());
    159   return b()->CreateVectorSplat(vector_size(), x, name());
    160 }
    161 
    162 llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
    163                                             llvm::Value* rhs) {
    164   AssertCorrectTypes({lhs, rhs});
    165   llvm::Type* int_type =
    166       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    167   return b()->CreateBitCast(
    168       b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
    169                      b()->CreateBitCast(rhs, int_type, name()), name()),
    170       vector_type());
    171 }
    172 
    173 llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
    174   AssertCorrectTypes({lhs});
    175   llvm::Type* int_type =
    176       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    177   return b()->CreateBitCast(
    178       b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
    179       vector_type());
    180 }
    181 
    182 llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
    183   AssertCorrectTypes({lhs, rhs});
    184   llvm::Type* int_type =
    185       IntegerTypeForFloatSize(lhs->getType() == vector_type());
    186   return b()->CreateBitCast(
    187       b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
    188                     b()->CreateBitCast(rhs, int_type, name()), name()),
    189       vector_type(), name());
    190 }
    191 
    192 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
    193                                                llvm::Value* rhs) {
    194   if (scalar_type_->isFloatingPointTy()) {
    195     return b()->CreateFAdd(lhs, rhs, name());
    196   } else {
    197     return b()->CreateAdd(lhs, rhs, name());
    198   }
    199 }
    200 
    201 llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    202     llvm::Value* base_pointer, llvm::Value* offset_elements) {
    203   if (base_pointer->getType() != scalar_pointer_type()) {
    204     base_pointer =
    205         b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
    206   }
    207   return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
    208 }
    209 
    210 llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
    211   if (pointer->getType() != vector_pointer_type()) {
    212     pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
    213   }
    214   return b()->CreateAlignedLoad(
    215       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
    216 }
    217 
    218 llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
    219   if (pointer->getType() != scalar_pointer_type()) {
    220     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
    221   }
    222   return b()->CreateAlignedLoad(
    223       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
    224 }
    225 
    226 void VectorSupportLibrary::StoreVector(llvm::Value* value,
    227                                        llvm::Value* pointer) {
    228   AssertCorrectTypes({value});
    229   if (pointer->getType() != vector_pointer_type()) {
    230     pointer = b()->CreateBitCast(pointer, vector_pointer_type());
    231   }
    232   b()->CreateAlignedStore(value, pointer,
    233                           ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
    234 }
    235 
    236 void VectorSupportLibrary::StoreScalar(llvm::Value* value,
    237                                        llvm::Value* pointer) {
    238   AssertCorrectTypes({value});
    239   if (pointer->getType() != scalar_pointer_type()) {
    240     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
    241   }
    242   b()->CreateAlignedStore(value, pointer,
    243                           ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
    244 }
    245 
    246 llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
    247   if (pointer->getType() != scalar_pointer_type()) {
    248     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
    249   }
    250   return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
    251                                 name());
    252 }
    253 
// Reduces `vector` to a single scalar: the sum of all its lanes.
//
// Works by repeatedly shuffling the upper half of the still-live lanes down
// onto the lower half and adding, halving the live lane count each round, so
// lane 0 ends up holding the total.
// NOTE(review): the halving schedule (`i >>= 1` down to 1) only covers every
// lane when vector_size() is a power of two — confirm at call sites.
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle lanes [i/2, i) down into lanes [0, i/2);
    // the rest of the mask is undef since those lanes are no longer live.
    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  // The running sum has accumulated into lane 0.
  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}
    276 
// Pairwise ("AVX-style") horizontal add: each lane of the result is the sum
// of an adjacent pair of lanes drawn from the concatenation lhs++rhs.
llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want.  The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //   --------+--+--+--+---+--+--+---+---
  //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //   mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  // mask_a selects the even member of each pair, mask_b the odd one; the
  // final add sums each pair into its result lane.
  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}
    317 
    318 llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
    319   llvm::SmallVector<llvm::Constant*, 32> mask;
    320   for (int i = 0; i < vector_size() / 2; i++) {
    321     mask.push_back(b()->getInt32(i));
    322   }
    323 
    324   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
    325                                   llvm::ConstantVector::get(mask));
    326 }
    327 
    328 llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
    329   llvm::SmallVector<llvm::Constant*, 32> mask;
    330   for (int i = 0; i < vector_size() / 2; i++) {
    331     mask.push_back(b()->getInt32(i + vector_size() / 2));
    332   }
    333 
    334   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
    335                                   llvm::ConstantVector::get(mask));
    336 }
    337 
    338 std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    339     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
    340   const int x86_avx_vector_elements =
    341       TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
    342   if (vector_size() == x86_avx_vector_elements &&
    343       vectors.size() == x86_avx_vector_elements) {
    344     return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
    345   }
    346 
    347   std::vector<llvm::Value*> result;
    348   std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
    349                  [this](llvm::Value* vector) { return AddReduce(vector); });
    350   if (init_values) {
    351     for (int64 i = 0, e = result.size(); i < e; i++) {
    352       result[i] = Add(result[i],
    353                       b()->CreateExtractElement(init_values, b()->getInt32(i)));
    354     }
    355   }
    356   return result;
    357 }
    358 
// Horizontally sums N vectors of N lanes each (N == vector_size()) with
// AVX-style pairwise horizontal adds, returning N scalar sums; lanes of
// `init_values` (if non-null) are added into the corresponding sums.
//
// NOTE(review): the loop below halves `vectors` until exactly two remain,
// so it assumes vectors.size() is an even power of two >= 2.  The only
// caller in this file (ComputeHorizontalSums) takes this path when
// vectors.size() equals the AVX element count, which satisfies this —
// confirm before adding new callers.
std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // vectors are N llvm vector values, each with N elements.
  int64 lane_width = vectors.size();

  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    // Sum i lives in `low` for the first half of the indices and in `high`
    // for the second half, at position i modulo the half-width.
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}
    398 
    399 llvm::Value* VectorSupportLibrary::GetZeroVector() {
    400   return llvm::Constant::getNullValue(vector_type());
    401 }
    402 
    403 llvm::Value* VectorSupportLibrary::GetZeroScalar() {
    404   return llvm::Constant::getNullValue(scalar_type());
    405 }
    406 
    407 LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
    408   alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
    409 }
    410 
    411 llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }
    412 
    413 void LlvmVariable::Set(llvm::Value* new_value) {
    414   b_->CreateStore(new_value, alloca_);
    415 }
    416 
    417 TileVariable::TileVariable(VectorSupportLibrary* vector_support,
    418                            std::vector<llvm::Value*> initial_value) {
    419   for (llvm::Value* initial_vector_value : initial_value) {
    420     storage_.emplace_back(vector_support, initial_vector_value);
    421   }
    422 }
    423 
    424 std::vector<llvm::Value*> TileVariable::Get() const {
    425   std::vector<llvm::Value*> result;
    426   absl::c_transform(storage_, std::back_inserter(result),
    427                     [&](VectorVariable vect_var) { return vect_var.Get(); });
    428   return result;
    429 }
    430 
    431 void TileVariable::Set(absl::Span<llvm::Value* const> value) {
    432   CHECK_EQ(value.size(), storage_.size());
    433   for (int64 i = 0, e = value.size(); i < e; i++) {
    434     storage_[i].Set(value[i]);
    435   }
    436 }
    437 
    438 }  // namespace cpu
    439 }  // namespace xla
    440