/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using absl::StrCat;

void HloToIrBindings::EmitBasePointersForHlos(
    absl::Span<const HloInstruction* const> io_hlos,
    absl::Span<const HloInstruction* const> non_io_hlos) {
  // I/O HLOs are bound to the arguments of the current IR function, i.e.,
  //
  //   void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound, to avoid rebinding the same HLO.
  absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    if (!already_bound_for_this_function.contains(io_hlo)) {
      if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
        BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
      } else {
        BindHloToIrValue(*io_hlo, &*arg_iter);
      }
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  // The remaining argument is the base pointer of the temporary buffer, from
  // which the buffers of non-I/O HLOs are carved out at fixed offsets.
  temp_buffer_base_ = &*arg_iter;
  temp_buffer_base_->setName("temp_buffer");

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.contains(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      if (!is_nested_) {
        // Look up the allocation of the GetTupleElement's non-GTE ancestor,
        // i.e. the tuple it is ultimately extracted from.
        const BufferAllocation::Slice slice =
            buffer_assignment_
                ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor())
                .ConsumeValueOrDie();
        // We are not in a nested context, so the allocation must not be
        // thread-local.
        CHECK(!slice.allocation()->is_thread_local());
        const int64 offset = slice.offset();
        CHECK_NE(nullptr, temp_buffer_base_);
        // Emit IR for the GetTupleElement instruction and bind the HLO to the
        // emitted value.
        llvm::Value* base_ptr =
            b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset));
        BindHloToIrValue(*non_io_hlo,
                         EmitGetTupleElement(non_io_hlo, base_ptr));
      }
      continue;
    }

    if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          // A non-I/O HLO with a buffer is bound to
          // (1) an alloca if it is thread-local,
          // (2) the module-level global for its constant buffer if it is a
          //     constant, or
          // (3) an internal pointer into temp_buffer_base according to its
          //     offset.
          auto slice_result =
              buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
          if (!slice_result.ok()) {
            return;
          }
          const BufferAllocation::Slice slice =
              slice_result.ConsumeValueOrDie();
          if (slice.allocation()->is_thread_local()) {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
                             index);
          } else if (slice.allocation()->is_constant()) {
            llvm::Value* global_for_constant = module_->getGlobalVariable(
                llvm_ir::ConstantBufferAllocationToGlobalName(
                    *slice.allocation()));
            BindHloToIrValue(*non_io_hlo, global_for_constant);
          } else {
            const int64 offset = slice.offset();
            CHECK_NE(nullptr, temp_buffer_base_);
            BindHloToIrValue(
                *non_io_hlo,
                b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)),
                index);
          }
        });
  }
}
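
// Illustrative sketch (not verbatim emitter output): for an entry computation
// with two I/O HLOs, the IR function whose arguments are bound above looks
// roughly like
//
//   define void @kernel(i8* %io_0, i8* %io_1, i8* %temp_buffer) { ... }
//
// and a non-I/O HLO whose slice starts at byte offset N is bound to the
// address %temp_buffer + N, computed with an inbounds GEP.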

llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    // Base case: the operand is the tuple itself, so load the element address
    // directly from base_ptr.
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
  }
  // Recursive case: resolve the operand GTE first, then index into its result.
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), b_);
}

// Returns true if `value` has a name that should not be changed.
static bool HasMeaningfulName(llvm::Value* value) {
  if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
    return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
  }
  return false;
}

llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              ShapeIndexView shape_index,
                                              llvm::Value* ir_value) {
  llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value = b_->CreateBitCast(ir_value, dest_type);
  }
  if (!HasMeaningfulName(ir_value)) {
    ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
  }
  if (!HasMeaningfulName(typed_ir_value)) {
    typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
  }
  return typed_ir_value;
}
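
// Illustrative sketch: binding an f32[4]-shaped HLO %x to an untyped i8*
// value yields a cast along the lines of
//
//   %x.typed = bitcast i8* %x.raw to [4 x float]*
//
// where the "raw"/"typed" suffixes are the names assigned by GetTypedIrValue
// above when the values have no meaningful name of their own.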

void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       ShapeIndexView shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Set the root of the ShapeTree before assigning the element IR value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}
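
// Illustrative sketch: for a tuple-shaped HLO t = (f32[], f32[4]), the emitter
// calls BindHloToIrValue once per bound subshape (e.g. with shape_index {0},
// then {1}), filling in the corresponding leaves of the ShapeTree that the
// first call creates.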

// Determines whether hlo's buffers are never modified within the execution of
// consumer.
static bool BuffersInvariantWithinConsumer(
    const HloInstruction& hlo, const HloInstruction& consumer,
    const BufferAssignment* buffer_assignment) {
  // Check if consumer is inside a fusion node -- if so, "dereference" it until
  // we get to a non-fusion node.
  const HloInstruction* c = &consumer;
  while (c->IsFused()) {
    c = c->parent()->FusionInstruction();
  }

  // If, after dereferencing c, we end up with a node that's not inside our
  // module's top-level computation (say our node is inside a while loop), we
  // give up on marking the array as invariant, because this HLO may be run
  // multiple times (e.g. multiple while loop iterations, or multiple
  // invocations of a reducer's computation).  TODO(jlebar): We could relax
  // this constraint if we emitted an llvm.invariant.group.barrier at the end
  // of the computation.
  return c->parent() == c->GetModule()->entry_computation() &&
         buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}
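
// For example, if `consumer` is fused into a fusion node F, the loop above
// walks up from `consumer` to F; `hlo`'s buffers are then treated as invariant
// only if F lives in the entry computation and `hlo`'s and `consumer`'s slices
// are disjoint.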

llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));
  alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);

  // The GPU backend emits one kernel per top-level HLO, and LLVM views
  // execution of one kernel as the "whole program" executed on the GPU.
  // Therefore if hlo's output buffer is not modified within consumer, and if
  // consumer runs hlo only once (so that it doesn't create two different
  // outputs), then we can mark ir_array as invariant over the whole program.
  if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return ir_array;
}
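
// Note: "invariant over the whole program" is from LLVM's point of view, where
// the program is a single kernel launch; loads through such an IrArray can be
// annotated as invariant (e.g. !invariant.load-style metadata), allowing LLVM
// to hoist or combine them.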

void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    // Bindings whose base pointer is a global variable (e.g. constant
    // buffers) are not local to the current function; keep those and unbind
    // everything else.
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over the module's computations in topological order, printing the
  // base pointers bound in each computation, with instructions also visited
  // in topological order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}
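
// Example output (illustrative values):
//
//   ** HloToIrBindings **
//     is_nested_=0
//     temp_buffer_base_=i8* %temp_buffer
//     Base pointers for computation main:
//       %add = f32[4]{0} add(%x, %y) -> [4 x float]* %add.typed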

}  // namespace gpu
}  // namespace xla