Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
     18 
#include <initializer_list>
#include <string>
#include <vector>

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
     26 
     27 namespace xla {
     28 namespace cpu {
     29 
     30 // Simple wrappers around llvm::APFloat::APFloat to make the calling code more
     31 // obvious.
     32 
     33 inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
     34 inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
     35   return llvm::APFloat(llvm::APFloat::IEEEsingle(),
     36                        llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
     37 }
     38 
     39 // A thin wrapper around llvm_util.h to make code generating vector math flow
     40 // more readable.
     41 class VectorSupportLibrary {
     42  public:
     43   // This VectorSupportLibrary instance remembers `primitive_type` and
     44   // `vector_size`, and these are implicitly used by the methods on this
     45   // instance (i.e. LoadVector will load a vector of type <`vector_size` x
     46   // `primitive_type`>).
     47   VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
     48                        llvm::IRBuilder<>* ir_builder, std::string name);
     49 
     50   llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
     51   llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
     52     return Mul(ir_builder()->getInt64(lhs), rhs);
     53   }
     54   llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
     55     return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
     56   }
     57 
     58   // If your call resolved to these then you probably wanted the versions taking
     59   // APFloat.
     60   llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
     61   llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
     62 
     63   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
     64   llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
     65     return Add(ir_builder()->getInt64(lhs), rhs);
     66   }
     67   llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
     68     return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
     69   }
     70 
     71   // If your call resolved to these then you probably wanted the versions taking
     72   // APFloat.
     73   llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
     74   llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
     75 
     76   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
     77   llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
     78     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
     79   }
     80   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
     81   llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
     82     return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
     83   }
     84   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
     85 
     86   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
     87     return Add(c, Mul(a, b));
     88   }
     89 
     90   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
     91     return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
     92   }
     93 
     94   llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
     95                       const llvm::APFloat& c) {
     96     return Add(GetConstantFloat(a->getType(), c),
     97                Mul(a, GetConstantFloat(a->getType(), b)));
     98   }
     99 
    100   llvm::Value* Floor(llvm::Value* a);
    101 
    102   llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
    103                      const llvm::APFloat& high);
    104   llvm::Value* SplatFloat(const llvm::APFloat& d) {
    105     return GetConstantFloat(vector_type(), d);
    106   }
    107 
    108   // These compare instructions return a floating point typed mask instead of an
    109   // i1.  For instance, on a vector typed input, lanes where the predicate is
    110   // true get a float with all ones and other lanes get a float with all zeros.
    111   // This is slightly odd from the perspective of LLVM's type system, but it
    112   // makes kernel IR generation code written using VectorSupportLibrary (its
    113   // raison d'etre) less cluttered.
    114 
    115   llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
    116   llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
    117   llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
    118   llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    119     return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
    120   }
    121 
    122   // These boolean operations operate on the bitwise values of the floating
    123   // point inputs.  They return a (vector of) float(s) but like in the mask
    124   // generating predicates above this type system oddity makes the kernel IR
    125   // generation code less cluttered.
    126   llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
    127   llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    128     return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
    129   }
    130   llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
    131   llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    132     return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
    133   }
    134   llvm::Value* FloatNot(llvm::Value* lhs);
    135   llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    136     return FloatAnd(FloatNot(lhs), rhs);
    137   }
    138 
    139   llvm::Value* BroadcastScalar(llvm::Value* x);
    140   llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    141     return BroadcastScalar(GetConstantFloat(scalar_type(), d));
    142   }
    143 
    144   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
    145                                     llvm::Value* offset_elements);
    146   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
    147                                     int64 offset_elements) {
    148     return ComputeOffsetPointer(base_pointer,
    149                                 ir_builder()->getInt64(offset_elements));
    150   }
    151 
    152   llvm::Value* LoadVector(llvm::Value* pointer);
    153 
    154   llvm::Value* LoadVector(llvm::Value* base_pointer,
    155                           llvm::Value* offset_elements) {
    156     return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
    157   }
    158 
    159   llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
    160     return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
    161   }
    162 
    163   llvm::Value* LoadScalar(llvm::Value* pointer);
    164 
    165   llvm::Value* LoadScalar(llvm::Value* base_pointer,
    166                           llvm::Value* offset_elements) {
    167     return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
    168   }
    169 
    170   llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
    171     return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
    172   }
    173 
    174   void StoreVector(llvm::Value* value, llvm::Value* pointer);
    175 
    176   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
    177                    llvm::Value* offset_elements) {
    178     StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
    179   }
    180 
    181   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
    182                    int64 offset_elements) {
    183     StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
    184   }
    185 
    186   void StoreScalar(llvm::Value* value, llvm::Value* pointer);
    187   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
    188                    llvm::Value* offset_elements) {
    189     StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
    190   }
    191 
    192   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
    193                    int64 offset_elements) {
    194     StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements));
    195   }
    196 
    197   llvm::Value* LoadBroadcast(llvm::Value* pointer);
    198   llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
    199                              llvm::Value* offset_elements) {
    200     return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
    201   }
    202   llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
    203     return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
    204   }
    205 
    206   // Compute the horizontal sum of each vector in `vectors`.  The i'th element
    207   // in the result vector is the (scalar) horizontal sum of the i'th vector in
    208   // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
    209   // in `init_values` is added to the i'th horizontal sum.
    210   std::vector<llvm::Value*> ComputeHorizontalSums(
    211       std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
    212 
    213   llvm::Value* GetZeroVector();
    214   llvm::Value* GetZeroScalar();
    215 
    216   llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
    217   int64 vector_size() const { return vector_size_; }
    218   llvm::Type* vector_type() const { return vector_type_; }
    219   llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
    220   llvm::Type* scalar_type() const { return scalar_type_; }
    221   llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
    222   int64 scalar_byte_size() const {
    223     return primitive_util::BitWidth(primitive_type_) / 8;
    224   }
    225 
    226   const std::string& name() const { return name_; }
    227 
    228  private:
    229   llvm::Value* ExtractLowHalf(llvm::Value*);
    230   llvm::Value* ExtractHighHalf(llvm::Value*);
    231 
    232   llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
    233   llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);
    234 
    235   llvm::Value* AddReduce(llvm::Value* vector);
    236 
    237   // Checks that each value in `values` is either of type scalar_type() or
    238   // vector_type().  This LOG(FATAL)'s so it should only be called in cases
    239   // where a mismatching type is a programmer bug.
    240   void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);
    241 
    242   // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
    243   // resulting IR for an 8-float wide vector is expected to lower to a single
    244   // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
    245   // other cases.
    246   //
    247   // For a vector width of 8, the result vector is computed as:
    248   //   Result[0] = Lhs[0] + Lhs[1]
    249   //   Result[1] = Lhs[2] + Lhs[3]
    250   //   Result[2] = Rhs[0] + Rhs[1]
    251   //   Result[3] = Rhs[2] + Rhs[3]
    252   //   Result[4] = Lhs[4] + Lhs[5]
    253   //   Result[5] = Lhs[6] + Lhs[7]
    254   //   Result[6] = Rhs[4] + Rhs[5]
    255   //   Result[7] = Rhs[6] + Rhs[7]
    256   llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);
    257 
    258   std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
    259       std::vector<llvm::Value*> vectors, llvm::Value* init_values);
    260 
    261   llvm::Type* IntegerTypeForFloatSize(bool vector);
    262   llvm::Value* I1ToFloat(llvm::Value* i1);
    263   llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    264     llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
    265     if (llvm::isa<llvm::VectorType>(type)) {
    266       return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
    267     }
    268     return scalar_value;
    269   }
    270 
    271   int64 vector_size_;
    272   PrimitiveType primitive_type_;
    273   llvm::IRBuilder<>* ir_builder_;
    274   llvm::Type* vector_type_;
    275   llvm::Type* vector_pointer_type_;
    276   llvm::Type* scalar_type_;
    277   llvm::Type* scalar_pointer_type_;
    278   std::string name_;
    279 };
    280 
// This wraps an alloca-backed stack variable which LLVM's SSA construction pass
// can later convert to an SSA value.
    283 class LlvmVariable {
    284  public:
    285   LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);
    286 
    287   llvm::Value* Get() const;
    288   void Set(llvm::Value* new_value);
    289 
    290  private:
    291   llvm::AllocaInst* alloca_;
    292   llvm::IRBuilder<>* ir_builder_;
    293 };
    294 
    295 class VectorVariable : public LlvmVariable {
    296  public:
    297   VectorVariable(VectorSupportLibrary* vector_support,
    298                  llvm::Value* initial_value)
    299       : LlvmVariable(vector_support->vector_type(),
    300                      vector_support->ir_builder()) {
    301     Set(initial_value);
    302   }
    303 };
    304 
    305 class ScalarVariable : public LlvmVariable {
    306  public:
    307   ScalarVariable(VectorSupportLibrary* vector_support,
    308                  llvm::Value* initial_value)
    309       : LlvmVariable(vector_support->scalar_type(),
    310                      vector_support->ir_builder()) {
    311     Set(initial_value);
    312   }
    313 };
    314 }  // namespace cpu
    315 }  // namespace xla
    316 
    317 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
    318