/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_

#include <string>
#include <vector>

#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {

// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
// obvious.

inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
}

// A thin wrapper around llvm_util.h to make code that generates vector math
// easier to read.
class VectorSupportLibrary {
 public:
  // This VectorSupportLibrary instance remembers `primitive_type` and
  // `vector_size`, and these are implicitly used by the methods on this
  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
  // `primitive_type`>).
  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
                       llvm::IRBuilder<>* b, std::string name);
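
  // A minimal usage sketch (illustrative only; `b` is assumed to be an
  // existing llvm::IRBuilder<>* and `ptr` a pointer to float data, both
  // hypothetical):
  //
  //   VectorSupportLibrary vsl(F32, /*vector_size=*/8, b, "my_kernel");
  //   llvm::Value* v = vsl.LoadVector(ptr);  // loads an <8 x float>
  //   vsl.StoreVector(vsl.Mul(v, v), ptr);   // stores v*v back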

  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
    return Mul(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
    return Add(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
  }
  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Add(c, Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
                      const llvm::APFloat& c) {
    return Add(GetConstantFloat(a->getType(), c),
               Mul(a, GetConstantFloat(a->getType(), b)));
  }

  llvm::Value* Floor(llvm::Value* a);

  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);
  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

  // These compare instructions return a floating point typed mask instead of
  // an i1. For instance, on a vector typed input, lanes where the predicate
  // is true get a float with all ones and other lanes get a float with all
  // zeros. This is slightly odd from the perspective of LLVM's type system,
  // but it makes kernel IR generation code written using VectorSupportLibrary
  // (its raison d'être) less cluttered.

  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }

  // These boolean operations operate on the bitwise values of the floating
  // point inputs. They return a (vector of) float(s) but, like the mask
  // generating predicates above, this type system oddity makes the kernel IR
  // generation code less cluttered.
  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatNot(llvm::Value* lhs);
  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    return FloatAnd(FloatNot(lhs), rhs);
  }
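
  // A sketch of the intended idiom (assumed usage, not mandated by this
  // header): the float-typed masks compose with the bitwise helpers above to
  // form a lane-wise select without materializing i1 vectors:
  //
  //   // result[i] = (a[i] < b[i]) ? x[i] : y[i]
  //   llvm::Value* mask = vsl.FCmpOLTMask(a, b);
  //   llvm::Value* result =
  //       vsl.FloatOr(vsl.FloatAnd(mask, x), vsl.FloatAndNot(mask, y));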

  llvm::Value* BroadcastScalar(llvm::Value* x);
  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
  }

  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements);
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements, int64 scale) {
    return ComputeOffsetPointer(
        base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
  }
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    int64 offset_elements) {
    return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* pointer);

  llvm::Value* LoadVector(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadVector(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* pointer);

  llvm::Value* LoadScalar(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadScalar(base_pointer, b()->getInt64(offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* pointer);

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreVector(value, base_pointer, b()->getInt64(offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreScalar(value, base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadBroadcast(llvm::Value* pointer);
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             llvm::Value* offset_elements) {
    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
  }
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             int64 offset_elements) {
    return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
  }
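
  // An illustrative sketch (names hypothetical): the offset overloads make
  // strided accesses read naturally. For example, loading the vector at
  // element i * row_stride of `base`, where `i` is an llvm::Value* index:
  //
  //   llvm::Value* row_start = vsl.ComputeOffsetPointer(
  //       base, /*offset_elements=*/i, /*scale=*/row_stride);
  //   llvm::Value* row = vsl.LoadVector(row_start);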

  // Compute the horizontal sum of each vector in `vectors`. The i'th element
  // in the result vector is the (scalar) horizontal sum of the i'th vector in
  // `vectors`. If `init_values` is not nullptr then the value in the i'th
  // lane of `init_values` is added to the i'th horizontal sum.
  std::vector<llvm::Value*> ComputeHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
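
  // An illustrative sketch (accumulator names hypothetical): a dot-product
  // kernel can reduce four vector accumulators to four scalar sums in one
  // call:
  //
  //   std::vector<llvm::Value*> sums =
  //       vsl.ComputeHorizontalSums({acc0, acc1, acc2, acc3});
  //   // sums[0] is the horizontal sum of acc0, and so on.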

  llvm::Value* GetZeroVector();
  llvm::Value* GetZeroScalar();

  llvm::IRBuilder<>* b() const { return b_; }
  int64 vector_size() const { return vector_size_; }
  llvm::Type* vector_type() const { return vector_type_; }
  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
  llvm::Type* scalar_type() const { return scalar_type_; }
  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
  int64 scalar_byte_size() const {
    return primitive_util::BitWidth(primitive_type_) / 8;
  }

  const std::string& name() const { return name_; }

 private:
  llvm::Value* ExtractLowHalf(llvm::Value*);
  llvm::Value* ExtractHighHalf(llvm::Value*);

  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* AddReduce(llvm::Value* vector);

  // Checks that each value in `values` is either of type scalar_type() or
  // vector_type(). This LOG(FATAL)'s, so it should only be called in cases
  // where a mismatching type is a programmer bug.
  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);

  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The
  // resulting IR for an 8-float wide vector is expected to lower to a single
  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
  // other cases.
  //
  // For a vector width of 8, the result vector is computed as:
  //   Result[0] = Lhs[0] + Lhs[1]
  //   Result[1] = Lhs[2] + Lhs[3]
  //   Result[2] = Rhs[0] + Rhs[1]
  //   Result[3] = Rhs[2] + Rhs[3]
  //   Result[4] = Lhs[4] + Lhs[5]
  //   Result[5] = Lhs[6] + Lhs[7]
  //   Result[6] = Rhs[4] + Rhs[5]
  //   Result[7] = Rhs[6] + Rhs[7]
  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);

  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values);

  llvm::Type* IntegerTypeForFloatSize(bool vector);
  llvm::Value* I1ToFloat(llvm::Value* i1);
  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
    if (llvm::isa<llvm::VectorType>(type)) {
      return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
    }
    return scalar_value;
  }

  int64 vector_size_;
  PrimitiveType primitive_type_;
  llvm::IRBuilder<>* b_;
  llvm::Type* vector_type_;
  llvm::Type* vector_pointer_type_;
  llvm::Type* scalar_type_;
  llvm::Type* scalar_pointer_type_;
  std::string name_;
};

// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
class LlvmVariable {
 public:
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  llvm::Value* Get() const;
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;
  llvm::IRBuilder<>* b_;
};

class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};

class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};

// This wraps a set of alloca-backed stack variables that can, as a whole,
// store a tile. A "tile" is a sequence of vectors that is typically used as a
// 2D grid of scalar values (e.g. for tiled GEMMs).
class TileVariable {
 public:
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_