/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;

void HloToIrBindings::EmitBasePointersForHlos(
    tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
    tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
  // I/O HLOs are bound to the arguments of the current IR function, i.e.,
  //
  //   void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
  llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound, to avoid rebinding the same HLO.
  std::set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    if (!already_bound_for_this_function.count(io_hlo)) {
      if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
        BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
      } else {
        BindHloToIrValue(*io_hlo, &*arg_iter);
      }
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  // The last argument is the base of the temporary buffer shared by all
  // non-I/O HLOs.
  temp_buffer_base_ = &*arg_iter;
  temp_buffer_base_->setName("temp_buffer");

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.count(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      if (!is_nested_) {
        // Look up the buffer slice assigned to the tuple that this
        // GetTupleElement ultimately reads from.
        const BufferAllocation::Slice slice =
            buffer_assignment_
                ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor())
                .ConsumeValueOrDie();
        // We are not in a nested context, so the allocation must not be
        // thread-local.
        CHECK(!slice.allocation()->is_thread_local());
        const int64 offset = slice.offset();
        CHECK_NE(nullptr, temp_buffer_base_);
        // Emit IR for the GetTupleElement instruction and bind the HLO to the
        // emitted value.
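        // The tuple's storage lives at a fixed offset inside the shared
        // temporary buffer, so its base address is an in-bounds GEP off
        // temp_buffer_base_.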
        llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP(
            temp_buffer_base_, ir_builder_->getInt64(offset));
        BindHloToIrValue(*non_io_hlo,
                         EmitGetTupleElement(non_io_hlo, base_ptr));
      }
      continue;
    }

    if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          // A non-I/O HLO with a buffer is bound to
          // (1) an alloca if it is thread-local, or
          // (2) an internal pointer into temp_buffer_base according to its
          //     offset.
          auto slice_result =
              buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
          if (!slice_result.ok()) {
            return;
          }
          const BufferAllocation::Slice slice =
              slice_result.ConsumeValueOrDie();
          if (slice.allocation()->is_thread_local()) {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo,
                             ir_builder_->CreateAlloca(pointee_type), index);
          } else {
            const int64 offset = slice.offset();
            CHECK_NE(nullptr, temp_buffer_base_);
            BindHloToIrValue(
                *non_io_hlo,
                ir_builder_->CreateInBoundsGEP(temp_buffer_base_,
                                               ir_builder_->getInt64(offset)),
                index);
          }
        });
  }
}

// Recursively emits the address computation for a (possibly nested)
// GetTupleElement, starting from base_ptr, the address of the outermost tuple
// in the chain.
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
}

// Returns ir_value cast to a pointer to the LLVM type corresponding to the
// given subshape of hlo, and names both the raw and the typed value after hlo.
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              const ShapeIndex& shape_index,
                                              llvm::Value* ir_value) {
  llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    // A bitcast of a global folds into a constant expression.
    typed_ir_value = llvm::ConstantExpr::getBitCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value = ir_builder_->CreateBitCast(ir_value, dest_type);
  }
  ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
  typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
  return typed_ir_value;
}

void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       const ShapeIndex& shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Set the root of the ShapeTree first, before assigning the element ir
    // value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

// Determines whether hlo's buffers are never modified within the execution of
// consumer.
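// "Never modified" is established via buffer assignment: if the slices
// backing hlo and consumer are disjoint, emitting consumer cannot write to
// hlo's buffers.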
static bool BuffersInvariantWithinConsumer(
    const HloInstruction& hlo, const HloInstruction& consumer,
    const BufferAssignment* buffer_assignment) {
  // Check if consumer is inside a fusion node -- if so, "dereference" it until
  // we get to a non-fusion node.
  const HloInstruction* c = &consumer;
  while (c->IsFused()) {
    c = c->parent()->FusionInstruction();
  }

  // If, after dereferencing c, we end up with a node that's not inside our
  // module's top-level computation (say our node is inside a while loop), we
  // give up on marking the array as invariant, because this HLO may be run
  // multiple times (e.g. multiple while loop iterations, or multiple
  // invocations of a reducer's computation).
  //
  // TODO(jlebar): We could relax this constraint if we emitted an
  // llvm.invariant.group.barrier at the end of the computation.
  return c->parent() == c->GetModule()->entry_computation() &&
         buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}

llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));
  alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);

  // The GPU backend emits one kernel per top-level HLO, and LLVM views
  // execution of one kernel as the "whole program" executed on the GPU.
  // Therefore, if hlo's output buffer is not modified within consumer, and if
  // consumer runs hlo only once (so that it doesn't create two different
  // outputs), then we can mark ir_array as invariant over the whole program.
  if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return ir_array;
}

// Unbinds every HLO whose base pointer is a function-local value. Bindings
// whose base pointer (after stripping pointer casts) is an
// llvm::GlobalVariable remain valid across functions, so they are kept.
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in topological order, and
  // print the base pointers we have bound in each computation, also in
  // topological order.
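  // base_ptrs_ is non-empty at this point, so any bound instruction serves as
  // a handle to the HloModule that all of them belong to.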
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!ShapeUtil::IsTuple(instr->shape())) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla