/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;

void HloToIrBindings::EmitBasePointersForHlos(
    tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
    tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
  // I/O HLOs are bound to the arguments of the current IR function. I.e.,
  //
  // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
  llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound to avoid rebinding the same HLO.
  std::set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    if (!already_bound_for_this_function.count(io_hlo)) {
      if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
        BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
      } else {
        BindHloToIrValue(*io_hlo, &*arg_iter);
      }
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  temp_buffer_base_ = &*arg_iter;
  temp_buffer_base_->setName("temp_buffer");

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.count(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      if (!is_nested_) {
        // Look up the allocation of the GetTupleElement's non-GTE ancestor.
        const BufferAllocation::Slice slice =
            buffer_assignment_
                ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor())
                .ConsumeValueOrDie();
        // We are not in a nested context, so the allocation must not be
        // thread-local.
        CHECK(!slice.allocation()->is_thread_local());
        const int64 offset = slice.offset();
        CHECK_NE(nullptr, temp_buffer_base_);
        // Emit IR for the GetTupleElement instruction and bind the HLO to
        // the emitted value.
        llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP(
            temp_buffer_base_, ir_builder_->getInt64(offset));
        BindHloToIrValue(*non_io_hlo,
                         EmitGetTupleElement(non_io_hlo, base_ptr));
      }
      continue;
    }

    if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          // A non-IO HLO with a buffer is bound to
          // (1) an alloca if it is thread-local, or
          // (2) an internal pointer in temp_buffer_base according to its
          // offset.
          auto slice_result =
              buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
          if (!slice_result.ok()) {
            return;
          }
          const BufferAllocation::Slice slice =
              slice_result.ConsumeValueOrDie();
          if (slice.allocation()->is_thread_local()) {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo,
                             ir_builder_->CreateAlloca(pointee_type), index);
          } else {
            const int64 offset = slice.offset();
            CHECK_NE(nullptr, temp_buffer_base_);
            BindHloToIrValue(
                *non_io_hlo,
                ir_builder_->CreateInBoundsGEP(temp_buffer_base_,
                                               ir_builder_->getInt64(offset)),
                index);
          }
        });
  }
}

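// Emits IR that computes the address of the buffer for `gte`, where
// `base_ptr` is the untyped base pointer of gte's latest non-GTE ancestor.
// Nested GetTupleElements are handled by recursing on the operand until a
// non-GTE operand is reached.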
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
}

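// Returns `ir_value` cast to a pointer to the LLVM type corresponding to the
// subshape of `hlo` at `shape_index`.  Global variables are cast with a
// constant expression, other values with a bitcast instruction.  Both the
// raw and the typed value are named after the HLO to make the emitted IR
// easier to read.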
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              const ShapeIndex& shape_index,
                                              llvm::Value* ir_value) {
  llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getBitCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value =
        ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo());
  }
  ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
  typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
  return typed_ir_value;
}

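// Binds `hlo` at `shape_index` to `ir_value`, after casting the value to the
// pointer type matching that subshape.  The first binding for an HLO creates
// its ShapeTree of base pointers; subsequent calls fill in other indices.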
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       const ShapeIndex& shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Create the ShapeTree for this HLO before assigning the element IR value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

// Determines whether hlo's buffers are never modified within the execution of
// consumer.
static bool BuffersInvariantWithinConsumer(
    const HloInstruction& hlo, const HloInstruction& consumer,
    const BufferAssignment* buffer_assignment) {
  // Check if consumer is inside a fusion node -- if so, "dereference" it until
  // we get to a non-fusion node.
  const HloInstruction* c = &consumer;
  while (c->IsFused()) {
    c = c->parent()->FusionInstruction();
  }

  // If, after dereferencing c, we end up with a node that's not inside our
  // module's top-level computation (say our node is inside a while loop), we
  // give up on marking the array as invariant, because this HLO may be run
  // multiple times (e.g. multiple while loop iterations, or multiple
  // invocations of a reducer's computation).  TODO(jlebar): We could relax
  // this constraint if we emitted an llvm.invariant.group.barrier at the end
  // of the computation.
  return c->parent() == c->GetModule()->entry_computation() &&
         buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}

llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));
  alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);

  // The GPU backend emits one kernel per top-level HLO, and LLVM views
  // execution of one kernel as the "whole program" executed on the GPU.
  // Therefore if hlo's output buffer is not modified within consumer, and if
  // consumer runs hlo only once (so that it doesn't create two different
  // outputs), then we can mark ir_array as invariant over the whole program.
  if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return ir_array;
}

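// Unbinds every HLO whose bound IR value is local to the current function,
// i.e. anything that is not (a cast of) an llvm::GlobalVariable.  Bindings to
// global variables are kept, since those remain valid across functions.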
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in topological order, and for
  // each computation print the base pointers of its instructions, also in
  // topological order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!ShapeUtil::IsTuple(instr->shape())) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla