/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using absl::StrCat;

void HloToIrBindings::EmitBasePointersForHlos(
    absl::Span<const HloInstruction* const> io_hlos,
    absl::Span<const HloInstruction* const> non_io_hlos) {
  // I/O HLOs are bound to the arguments of the current IR function, i.e.,
  //
  //   void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { ... }
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound to avoid rebinding the same HLO.
  absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    if (!already_bound_for_this_function.contains(io_hlo)) {
      if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
        BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
      } else {
        BindHloToIrValue(*io_hlo, &*arg_iter);
      }
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  temp_buffer_base_ = &*arg_iter;
  temp_buffer_base_->setName("temp_buffer");

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.contains(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      if (!is_nested_) {
        // Look up the allocation that the GetTupleElement's operand lives in.
        const BufferAllocation::Slice slice =
            buffer_assignment_
                ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor())
                .ConsumeValueOrDie();
        // We are not in a nested context, so the allocation cannot be
        // thread-local.
        CHECK(!slice.allocation()->is_thread_local());
        const int64 offset = slice.offset();
        CHECK_NE(nullptr, temp_buffer_base_);
        // Emit IR for the GetTupleElement instruction and bind the HLO to the
        // emitted value.
        llvm::Value* base_ptr =
            b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset));
        BindHloToIrValue(*non_io_hlo,
                         EmitGetTupleElement(non_io_hlo, base_ptr));
      }
      continue;
    }

    if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          // A non-I/O HLO with a buffer is bound to
          // (1) an alloca if it is thread-local,
          // (2) a module-level global if it lives in a constant buffer, or
          // (3) an internal pointer into temp_buffer_base according to its
          //     offset.
          auto slice_result =
              buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
          if (!slice_result.ok()) {
            return;
          }
          const BufferAllocation::Slice slice =
              slice_result.ConsumeValueOrDie();
          if (slice.allocation()->is_thread_local()) {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
                             index);
          } else if (slice.allocation()->is_constant()) {
            llvm::Value* global_for_constant = module_->getGlobalVariable(
                llvm_ir::ConstantBufferAllocationToGlobalName(
                    *slice.allocation()));
            BindHloToIrValue(*non_io_hlo, global_for_constant);
          } else {
            const int64 offset = slice.offset();
            CHECK_NE(nullptr, temp_buffer_base_);
            BindHloToIrValue(
                *non_io_hlo,
                b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)),
                index);
          }
        });
  }
}

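// Emits the address of `gte`'s buffer by walking the (possibly nested) chain
// of get-tuple-element instructions down from `base_ptr`, which points at the
// buffer of the outermost non-GTE ancestor. For example (hypothetical HLO,
// for illustration only), given
//
//   t = tuple(a, b)
//   u = tuple(t, c)
//   x = get-tuple-element(get-tuple-element(u, 0), 1)
//
// the inner GTE is emitted first against `base_ptr` (yielding t's address),
// and the outer GTE is then emitted against that result (yielding b's
// address).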
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), b_);
}

// Returns true if `value` has a name that should not be changed.
static bool HasMeaningfulName(llvm::Value* value) {
  if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
    return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
  }
  return false;
}

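// Returns `ir_value` cast to a pointer to the LLVM IR type corresponding to
// the subshape of `hlo` at `shape_index`. Globals are cast with a constant
// expression, which keeps the result a constant; other values get a bitcast
// instruction at the current insertion point. Both the raw and the typed
// value are given debug-friendly names derived from the HLO unless they
// already carry meaningful names.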
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              ShapeIndexView shape_index,
                                              llvm::Value* ir_value) {
  llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo());
  }
  if (!HasMeaningfulName(ir_value)) {
    ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
  }
  if (!HasMeaningfulName(typed_ir_value)) {
    typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
  }
  return typed_ir_value;
}

void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       ShapeIndexView shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Set the root of the ShapeTree first before assigning the element IR
    // value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

// Determines whether hlo's buffers are never modified within the execution of
// consumer.
static bool BuffersInvariantWithinConsumer(
    const HloInstruction& hlo, const HloInstruction& consumer,
    const BufferAssignment* buffer_assignment) {
  // Check if consumer is inside a fusion node -- if so, "dereference" it until
  // we get to a non-fusion node.
  const HloInstruction* c = &consumer;
  while (c->IsFused()) {
    c = c->parent()->FusionInstruction();
  }

  // If, after dereferencing c, we end up with a node that's not inside our
  // module's top-level computation (say our node is inside a while loop), we
  // give up on marking the array as invariant, because this HLO may be run
  // multiple times (e.g. multiple while loop iterations, or multiple
  // invocations of a reducer's computation).
  //
  // TODO(jlebar): We could relax this constraint if we emitted an
  // llvm.invariant.group.barrier at the end of the computation.
  return c->parent() == c->GetModule()->entry_computation() &&
         buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}

llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));
  alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);

  // The GPU backend emits one kernel per top-level HLO, and LLVM views
  // execution of one kernel as the "whole program" executed on the GPU.
  // Therefore if hlo's output buffer is not modified within consumer, and if
  // consumer runs hlo only once (so that it doesn't create two different
  // outputs), then we can mark ir_array as invariant over the whole program.
  if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return ir_array;
}

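// Drops every binding whose base pointer is not a module-level global (after
// stripping pointer casts). Kernel arguments, allocas, and GEPs into
// temp_buffer_base_ are function-local values that cannot outlive the IR
// function they were emitted in, so those bindings are removed; bindings to
// module-level globals (e.g. constant buffers) are kept.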
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in topological order, and
  // print out the base pointers we have in each computation in topological
  // order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla