/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_

#include <string>

#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {

// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
// obvious.

inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
}

// A thin wrapper around llvm_util.h to make code that generates vector math
// more readable.
class VectorSupportLibrary {
 public:
  // This VectorSupportLibrary instance remembers `primitive_type` and
  // `vector_size`, and these are implicitly used by the methods on this
  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
  // `primitive_type`>).
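  //
  // A minimal usage sketch (the llvm::IRBuilder `b` and the pointer `ptr`
  // below are hypothetical, not part of this header):
  //
  //   VectorSupportLibrary vsl(F32, /*vector_size=*/8, &b, "my_kernel");
  //   llvm::Value* v = vsl.LoadVector(ptr);                // <8 x float>
  //   llvm::Value* scaled = vsl.Mul(GetIeeeF32(0.5f), v);  // 0.5f * v
  //   vsl.StoreVector(scaled, ptr);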
  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
                       llvm::IRBuilder<>* b, std::string name);

  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
    return Mul(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions taking
  // APFloat.
  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
    return Add(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions taking
  // APFloat.
  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
  }
  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Add(c, Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
                      const llvm::APFloat& c) {
    return Add(GetConstantFloat(a->getType(), c),
               Mul(a, GetConstantFloat(a->getType(), b)));
  }

  llvm::Value* Floor(llvm::Value* a);

  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);
  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

  // These compare instructions return a floating point typed mask instead of an
  // i1.  For instance, on a vector typed input, lanes where the predicate is
  // true get a float with all ones and other lanes get a float with all zeros.
  // This is slightly odd from the perspective of LLVM's type system, but it
  // makes kernel IR generation code written using VectorSupportLibrary (its
  // raison d'etre) less cluttered.

  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }

  // These boolean operations operate on the bitwise values of the floating
  // point inputs.  They return a (vector of) float(s), but, as with the
  // mask-generating predicates above, this type-system oddity makes the kernel
  // IR generation code less cluttered.
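  //
  // Together with the FCmp*Mask methods above, these compose into branch-free
  // lane-wise selects.  For instance, a sketch of lane-wise min(x, y), where
  // `x` and `y` are hypothetical values of the vector type:
  //
  //   llvm::Value* mask = FCmpOLTMask(x, y);  // all ones where x < y
  //   llvm::Value* min = FloatOr(FloatAnd(mask, x), FloatAndNot(mask, y));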
  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatNot(llvm::Value* lhs);
  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    return FloatAnd(FloatNot(lhs), rhs);
  }

  llvm::Value* BroadcastScalar(llvm::Value* x);
  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
  }

  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements);
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements, int64 scale) {
    return ComputeOffsetPointer(
        base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
  }
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    int64 offset_elements) {
    return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* pointer);

  llvm::Value* LoadVector(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadVector(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* pointer);

  llvm::Value* LoadScalar(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadScalar(base_pointer, b()->getInt64(offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* pointer);

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreVector(value, base_pointer, b()->getInt64(offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreScalar(value, base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadBroadcast(llvm::Value* pointer);
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             llvm::Value* offset_elements) {
    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
  }
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
  }

  // Compute the horizontal sum of each vector in `vectors`.  The i'th element
  // in the result vector is the (scalar) horizontal sum of the i'th vector in
  // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
  // in `init_values` is added to the i'th horizontal sum.
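  //
  // A sketch of typical use (the vector accumulators `acc0`/`acc1` and the
  // pointer `result_ptr` are hypothetical):
  //
  //   std::vector<llvm::Value*> sums = ComputeHorizontalSums({acc0, acc1});
  //   StoreScalar(sums[0], result_ptr, /*offset_elements=*/0);
  //   StoreScalar(sums[1], result_ptr, /*offset_elements=*/1);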
  std::vector<llvm::Value*> ComputeHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);

  llvm::Value* GetZeroVector();
  llvm::Value* GetZeroScalar();

  llvm::IRBuilder<>* b() const { return b_; }
  int64 vector_size() const { return vector_size_; }
  llvm::Type* vector_type() const { return vector_type_; }
  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
  llvm::Type* scalar_type() const { return scalar_type_; }
  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
  int64 scalar_byte_size() const {
    return primitive_util::BitWidth(primitive_type_) / 8;
  }

  const std::string& name() const { return name_; }

 private:
  llvm::Value* ExtractLowHalf(llvm::Value*);
  llvm::Value* ExtractHighHalf(llvm::Value*);

  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* AddReduce(llvm::Value* vector);

  // Checks that each value in `values` is either of type scalar_type() or
  // vector_type().  This LOG(FATAL)'s so it should only be called in cases
  // where a mismatching type is a programmer bug.
  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);

  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
  // resulting IR for an 8-float wide vector is expected to lower to a single
  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
  // other cases.
  //
  // For a vector width of 8, the result vector is computed as:
  //   Result[0] = Lhs[0] + Lhs[1]
  //   Result[1] = Lhs[2] + Lhs[3]
  //   Result[2] = Rhs[0] + Rhs[1]
  //   Result[3] = Rhs[2] + Rhs[3]
  //   Result[4] = Lhs[4] + Lhs[5]
  //   Result[5] = Lhs[6] + Lhs[7]
  //   Result[6] = Rhs[4] + Rhs[5]
  //   Result[7] = Rhs[6] + Rhs[7]
  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);

  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values);

  llvm::Type* IntegerTypeForFloatSize(bool vector);
  llvm::Value* I1ToFloat(llvm::Value* i1);
  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
    if (llvm::isa<llvm::VectorType>(type)) {
      return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
    }
    return scalar_value;
  }

  int64 vector_size_;
  PrimitiveType primitive_type_;
  llvm::IRBuilder<>* b_;
  llvm::Type* vector_type_;
  llvm::Type* vector_pointer_type_;
  llvm::Type* scalar_type_;
  llvm::Type* scalar_pointer_type_;
  std::string name_;
};

// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
class LlvmVariable {
 public:
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  llvm::Value* Get() const;
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;
  llvm::IRBuilder<>* b_;
};

class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};

class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};

// This wraps a set of alloca-backed stack variables that can, as a whole, store
// a tile.  A "tile" is a sequence of vectors that is typically used as a 2D
// grid of scalar values (e.g. for tiled GEMMs).
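//
// An illustrative sketch of how a tile of accumulators might be kept live
// across a loop (the `vsl` instance and the values `a` and `b` are
// hypothetical):
//
//   TileVariable tile(&vsl, {vsl.GetZeroVector(), vsl.GetZeroVector()});
//   std::vector<llvm::Value*> acc = tile.Get();
//   acc[0] = vsl.MulAdd(a, b, acc[0]);
//   tile.Set(acc);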
class TileVariable {
 public:
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;
};
}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_