Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
     18 
     19 #include <unordered_map>
     20 
     21 #include "llvm/IR/IRBuilder.h"
     22 #include "llvm/IR/Module.h"
     23 #include "llvm/IR/Value.h"
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
     26 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
     27 #include "tensorflow/compiler/xla/statusor.h"
     28 
     29 namespace xla {
     30 
     31 class ElementalIrEmitter {
     32  public:
     33   using HloToElementGeneratorMap =
     34       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
     35 
     36   ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
     37                      llvm::Module* module, llvm::IRBuilder<>* ir_builder)
     38       : ir_builder_(ir_builder),
     39         module_(module),
     40         hlo_module_config_(hlo_module_config) {}
     41 
     42   virtual ~ElementalIrEmitter() = default;
     43 
     44   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
     45                                              llvm::Value* operand_value) const;
     46 
     47   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
     48                                               llvm::Value* lhs_value,
     49                                               llvm::Value* rhs_value) const;
     50 
     51   // Returns a function to generate an element of the output of `hlo`, given a
     52   // map of functions to generate elements of its operands.
     53   virtual llvm_ir::ElementGenerator MakeElementGenerator(
     54       const HloInstruction* hlo,
     55       const HloToElementGeneratorMap& operand_to_generator) const;
     56 
     57   llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
     58   llvm::Module* module() const { return module_; }
     59 
     60  protected:
     61   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
     62       const HloInstruction* op, llvm::Value* operand_value) const;
     63 
     64   virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
     65       const HloInstruction* op, llvm::Value* operand_value) const;
     66 
     67   virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
     68       const HloInstruction* op, llvm::Value* operand_value) const;
     69 
     70   virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
     71                                                      llvm::Value* lhs_value,
     72                                                      llvm::Value* rhs_value,
     73                                                      bool is_signed) const;
     74 
     75   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
     76       const HloInstruction* op, llvm::Value* lhs_value,
     77       llvm::Value* rhs_value) const;
     78 
     79   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
     80       const HloInstruction* op, llvm::Value* lhs_value,
     81       llvm::Value* rhs_value) const;
     82 
     83   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
     84                                     llvm::Value* rhs_value) const;
     85 
     86   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
     87                                     llvm::Value* rhs_value) const;
     88 
     89   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
     90                                bool is_signed) const;
     91 
     92   llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
     93                                bool is_signed) const;
     94 
     95   virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
     96                                             llvm::Value* value) const;
     97 
     98   virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
     99                                              llvm::Value* value) const;
    100 
    101   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
    102                                            llvm::Value* lhs,
    103                                            llvm::Value* rhs) const;
    104 
    105   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
    106                                          llvm::Value* value) const;
    107 
    108   virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
    109                                          llvm::Value* value) const;
    110 
    111   virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
    112                                          llvm::Value* value) const;
    113 
    114   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
    115                                          llvm::Value* value) const;
    116 
    117   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
    118                                          llvm::Value* lhs,
    119                                          llvm::Value* rhs) const;
    120 
    121   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
    122                                                      llvm::Value* x) const;
    123 
    124   virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;
    125   virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;
    126 
    127   // Composes a complex struct. imag may be nullptr for simple cast operations.
    128   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
    129                                   llvm::Value* imag) const;
    130 
    131   // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
    132   // the target array index, computes the source array index of its
    133   // `operand_no`-th operand.
    134   //
    135   // Precondition: `hlo` is an elementwise op.
    136   llvm_ir::IrArray::Index ElementwiseSourceIndex(
    137       const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
    138       int64 operand_no) const;
    139 
    140   // Identifier of the thread unique among all threads on the device
    141   virtual llvm::Value* EmitThreadId() const {
    142     return ir_builder_->getIntN(128, 0);
    143   }
    144 
    145   llvm::IRBuilder<>* const ir_builder_;
    146 
    147   llvm::Module* module_;
    148 
    149   // The HloModuleConfig which gathers all settings and values which affect the
    150   // compiled executable outside of the HLO code itself.
    151   const HloModuleConfig& hlo_module_config_;
    152 
    153  private:
    154   // Returns a ElementGenerator for a RNG HloInstruction.
    155   llvm_ir::ElementGenerator MakeRngElementGenerator(
    156       const HloInstruction* hlo,
    157       const HloToElementGeneratorMap& operand_to_generator) const;
    158 };
    159 
    160 }  // namespace xla
    161 
    162 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
    163