/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"

#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64 vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}

static string TypeToString(llvm::Type* type) {
  std::string o;
  llvm::raw_string_ostream ostream(o);
  type->print(ostream);
  return ostream.str();
}

void VectorSupportLibrary::AssertCorrectTypes(
    std::initializer_list<llvm::Value*> values) {
  for (llvm::Value* v : values) {
    llvm::Type* type = v->getType();
    if (type != scalar_type() && type != vector_type()) {
      LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
                 << TypeToString(vector_type()) << " but got "
                 << TypeToString(type);
    }
  }
}

llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFMul(lhs, rhs, name());
  } else {
    return b()->CreateMul(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return b()->CreateFSub(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return llvm_ir::EmitFloatMax(lhs, rhs, b_);
  } else {
    LOG(FATAL) << "Max for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
  AssertCorrectTypes({a});
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
                                      {a->getType()}, b());
}

llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFDiv(lhs, rhs, name());
  } else {
    LOG(FATAL) << "Division for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
  CHECK(scalar_type_->isFloatingPointTy());
  return llvm_ir::EmitFloatMin(
      llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
      GetConstantFloat(type, high), b_);
}
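
// The FCmp*Mask helpers below produce comparison masks in the floating-point
// domain: each result lane is either all-zero bits (false) or all-one bits
// (true).  I1ToFloat implements this by sign-extending the i1 comparison
// result to an integer with the same bit width as the float type and
// bitcasting it back; FloatAnd, FloatOr and FloatNot manipulate such masks by
// round-tripping through that integer type.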

llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
                                              llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
  return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
                            is_vector ? vector_type() : scalar_type(), name());
}

llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
  CHECK(scalar_type()->isFloatingPointTy());
  const llvm::DataLayout& data_layout =
      b()->GetInsertBlock()->getModule()->getDataLayout();
  int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
  llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
  if (vector) {
    return llvm::VectorType::get(scalar_int_type, vector_size());
  } else {
    return scalar_int_type;
  }
}

llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
  CHECK_EQ(x->getType(), scalar_type());
  return b()->CreateVectorSplat(vector_size(), x, name());
}

llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
                                            llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
                     b()->CreateBitCast(rhs, int_type, name()), name()),
      vector_type());
}

llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
  AssertCorrectTypes({lhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
      vector_type());
}

llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
                    b()->CreateBitCast(rhs, int_type, name()), name()),
      vector_type(), name());
}

llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFAdd(lhs, rhs, name());
  } else {
    return b()->CreateAdd(lhs, rhs, name());
  }
}
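
// The memory helpers below bitcast the incoming pointer to the expected
// scalar or vector pointer type when needed, and annotate each access with
// only a single element's alignment (ShapeUtil::ByteSizeOfPrimitiveType), so
// vector loads and stores do not assume vector-width alignment.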

llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    llvm::Value* base_pointer, llvm::Value* offset_elements) {
  if (base_pointer->getType() != scalar_pointer_type()) {
    base_pointer =
        b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
  }
  return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
}

llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}

llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}

void VectorSupportLibrary::StoreVector(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type());
  }
  b()->CreateAlignedStore(value, pointer,
                          ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}

void VectorSupportLibrary::StoreScalar(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  b()->CreateAlignedStore(value, pointer,
                          ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}

llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
                                name());
}

llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle the upper half of the remaining lanes down
    // into the lower half and add the shuffled vector to the original one.
    // E.g. for vector_size() == 8 the shuffle masks are <4,5,6,7,u,u,u,u>,
    // then <2,3,u,u,u,u,u,u>, then <1,u,u,u,u,u,u,u>, leaving the total in
    // lane 0.
    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}
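
// Adds adjacent lanes pairwise, in the style of AVX's vhaddps: with
// vector_size() == 8 and inputs lhs = [l0..l7] and rhs = [r0..r7], the
// result is
//   [l0+l1, l2+l3, r0+r1, r2+r3, l4+l5, l6+l7, r4+r5, r6+r7].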

llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want.  The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //   index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //  --------+--+--+--+---+--+--+---+---
  //  mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //  mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment =
        i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}

llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i + vector_size() / 2));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  const int x86_avx_vector_elements =
      TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
  if (vector_size() == x86_avx_vector_elements &&
      vectors.size() == x86_avx_vector_elements) {
    return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
  }

  std::vector<llvm::Value*> result;
  std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
                 [this](llvm::Value* vector) { return AddReduce(vector); });
  if (init_values) {
    for (int64 i = 0, e = result.size(); i < e; i++) {
      result[i] = Add(result[i],
                      b()->CreateExtractElement(init_values, b()->getInt32(i)));
    }
  }
  return result;
}

std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // vectors are N llvm vector values, each with N elements.
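  // Each round of AvxStyleHorizontalAdd folds a pair of vectors into one, so
  // the loop below reduces the N vectors to 2; extracting and adding the low
  // and high halves of those two then yields the N horizontal sums, with lane
  // i of `init_values` (when non-null) folded into the i'th sum.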
  int64 lane_width = vectors.size();

  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}

llvm::Value* VectorSupportLibrary::GetZeroVector() {
  return llvm::Constant::getNullValue(vector_type());
}

llvm::Value* VectorSupportLibrary::GetZeroScalar() {
  return llvm::Constant::getNullValue(scalar_type());
}

LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}

llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }

void LlvmVariable::Set(llvm::Value* new_value) {
  b_->CreateStore(new_value, alloca_);
}

TileVariable::TileVariable(VectorSupportLibrary* vector_support,
                           std::vector<llvm::Value*> initial_value) {
  for (llvm::Value* initial_vector_value : initial_value) {
    storage_.emplace_back(vector_support, initial_vector_value);
  }
}

std::vector<llvm::Value*> TileVariable::Get() const {
  std::vector<llvm::Value*> result;
  absl::c_transform(storage_, std::back_inserter(result),
                    [&](VectorVariable vect_var) { return vect_var.Get(); });
  return result;
}

void TileVariable::Set(absl::Span<llvm::Value* const> value) {
  CHECK_EQ(value.size(), storage_.size());
  for (int64 i = 0, e = value.size(); i < e; i++) {
    storage_[i].Set(value[i]);
  }
}

}  // namespace cpu
}  // namespace xla