Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
     18 
     19 #include <unordered_map>
     20 
     21 #include "llvm/IR/IRBuilder.h"
     22 #include "llvm/IR/Value.h"
     23 #include "tensorflow/compiler/xla/map_util.h"
     24 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     26 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
     27 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
     28 #include "tensorflow/core/lib/gtl/array_slice.h"
     29 
     30 namespace xla {
     31 namespace gpu {
     32 
     33 // This class encapsulates the bindings between HloInstructions and LLVM IR
     34 // values that represent their addresses.
     35 class HloToIrBindings {
     36  public:
     37   HloToIrBindings(const HloModule& module,
     38                   const BufferAssignment* buffer_assignment,
     39                   llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module,
     40                   bool is_nested)
     41       : buffer_assignment_(buffer_assignment),
     42         is_nested_(is_nested),
     43         ir_builder_(ir_builder),
     44         module_(llvm_module),
     45         alias_analysis_(module, *buffer_assignment_,
     46                         &ir_builder_->getContext()) {}
     47 
     48   void EmitBasePointersForHlos(
     49       tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
     50       tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);
     51 
     52   // Rebinds the given HLO to the LLVM IR value that represent its address.
     53   void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
     54                         const ShapeIndex& shape_index = {});
     55 
     56   // Unbinds all IR values that's defined in an LLVM function, e.g., function
     57   // arguments and stack variables. Global variables will be kept in bindings_.
     58   //
     59   // This method is called after emitting code for each top-level HLO. The local
     60   // IR values are out of scope at that point and should not be used.
     61   void UnbindAllLocalIrValues();
     62 
     63   // Returns whether `hlo` is bound to an LLVM IR value.
     64   bool BoundToIrValue(const HloInstruction& hlo) const {
     65     return base_ptrs_.count(&hlo);
     66   }
     67 
     68   llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; }
     69   void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; }
     70 
     71   // A helper method that returns the base pointer of the IrArray containing the
     72   // output of "inst".at the given ShapeIndex.
     73   llvm::Value* GetBasePointer(const HloInstruction& hlo,
     74                               const ShapeIndex& shape_index = {}) const {
     75     auto it = base_ptrs_.find(&hlo);
     76     CHECK(it != base_ptrs_.end()) << hlo.ToString();
     77     return it->second.element(shape_index);
     78   }
     79 
     80   // Returns the IrArray which contains the output of hlo.
     81   //
     82   // consumer is the HLO in which this IrArray is used -- we use this to (try
     83   // to) add metadata indicating that the array is invariant within consumer.
     84   //
     85   // To get the buffer into which hlo should write its own output, call
     86   // GetIrArray(hlo, hlo).
     87   llvm_ir::IrArray GetIrArray(const HloInstruction& hlo,
     88                               const HloInstruction& consumer,
     89                               const ShapeIndex& shape_index = {});
     90 
     91   string ToString() const;
     92 
     93  private:
     94   // Emits IR to resolve (possibly) recursive GetTupleElement instructions.
     95   llvm::Value* EmitGetTupleElement(const HloInstruction* gte,
     96                                    llvm::Value* base_ptr);
     97 
     98   // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape.
     99   llvm::Value* GetTypedIrValue(const HloInstruction& hlo,
    100                                const ShapeIndex& shape_index,
    101                                llvm::Value* ir_value);
    102 
    103   const BufferAssignment* buffer_assignment_;
    104 
    105   const bool is_nested_;
    106 
    107   llvm::IRBuilder<>* ir_builder_;
    108   llvm::Module* module_;
    109 
    110   // Stores the underlying llvm::IrArray for each HloInstruction.
    111   // For an instruction that generates multiple outputs, the root will be a
    112   // tuple shape. The IrArray for each element output is stored in the subnode
    113   // in the ShapeTree.
    114   std::unordered_map<const HloInstruction*, ShapeTree<llvm::Value*>> base_ptrs_;
    115 
    116   // The address of the memory block that contains all temporary buffers.
    117   llvm::Value* temp_buffer_base_ = nullptr;
    118 
    119   llvm_ir::AliasAnalysis alias_analysis_;
    120 };
    121 
    122 }  // namespace gpu
    123 }  // namespace xla
    124 
    125 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
    126