1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 17 18 #include <utility> 19 #include <vector> 20 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/liveness_util.h" 23 #include "tensorflow/compiler/xla/shape_util.h" 24 #include "tensorflow/compiler/xla/status_macros.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/util.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/strings/str_util.h" 30 #include "tensorflow/core/lib/strings/stringprintf.h" 31 #include "tensorflow/core/platform/logging.h" 32 33 namespace xla { 34 35 bool HloOrdering::ExecutesBefore(const HloInstruction* a, 36 const HloInstruction* b) const { 37 // 'a' and 'b' may be in different computations. In this case, find the 38 // callgraph ancestor instructions which call (potentially transitively) the 39 // computations containing 'a' and 'b' and use these ancestor instructions to 40 // compare order. 41 const HloInstruction* a_ancestor; 42 const HloInstruction* b_ancestor; 43 std::tie(a_ancestor, b_ancestor) = 44 call_graph_->NearestAncestorsInSameComputation( 45 const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b)); 46 47 if (a_ancestor == nullptr) { 48 // Ancestors in a common computation could not be found so consider the 49 // instructions 'a' and 'b' to be unordered. 50 return false; 51 } 52 // a_ancestor and b_ancestor must be either both null or both non-null. 53 CHECK_NE(b_ancestor, nullptr); 54 CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); 55 56 // If the common ancestor is a while instruction there is an additional 57 // ordering criteria which may apply. The condition computation is considered 58 // to execute before the body computation so if 'a' is in the condition and 59 // 'b' is in the body, then 'a' executes before 'b'. 60 if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) { 61 const HloComputation* body = a_ancestor->while_body(); 62 const HloComputation* condition = a_ancestor->while_condition(); 63 if (call_graph_->InstructionIsNestedIn(a, condition) && 64 call_graph_->InstructionIsNestedIn(b, body)) { 65 return true; 66 } 67 } 68 69 return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); 70 } 71 72 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { 73 // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' 74 // is live into the module. 75 const HloModule* module = b.defining_instruction()->parent()->parent(); 76 if (b.defining_instruction()->parent() == module->entry_computation() && 77 b.defining_instruction()->opcode() == HloOpcode::kParameter) { 78 return false; 79 } 80 81 // Phi values require special handling. Because XLA does not have a phi 82 // instruction, the definition instruction of the phis values are 83 // placeholders: either the subcomputation parameter (body or condition) or 84 // the while instruction. However, the program point where these values are 85 // logically defined does not necessarily coincide exactly with program point 86 // of these place-holder instructions. So we explicitly define the following 87 // order for phi values: 88 // 89 // body/condition parameter phi: 90 // Defined before all values defined in its computation excepting other 91 // phis. 92 // 93 // while phi: 94 // defined after all values defined in the condition or body. 95 // 96 auto is_body_or_condition_phi = [](const HloValue& v) { 97 return v.is_phi() && 98 v.defining_instruction()->opcode() == HloOpcode::kParameter; 99 }; 100 if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && 101 call_graph_->InstructionIsNestedIn(b.defining_instruction(), 102 a.defining_instruction()->parent())) { 103 return true; 104 } 105 if (is_body_or_condition_phi(b) && 106 call_graph_->InstructionIsNestedIn(a.defining_instruction(), 107 b.defining_instruction()->parent())) { 108 return false; 109 } 110 111 // If 'b' is a while phi and 'a' is in the body or condition, then 'a' 112 // executes before 'b'. 113 if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && 114 (call_graph_->InstructionIsNestedIn( 115 a.defining_instruction(), b.defining_instruction()->while_body()) || 116 call_graph_->InstructionIsNestedIn( 117 a.defining_instruction(), 118 b.defining_instruction()->while_condition()))) { 119 return true; 120 } 121 122 return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); 123 } 124 125 /* static */ 126 bool HloOrdering::UseIsBeforeValueDefinition( 127 const HloUse& use, const HloValue& value, 128 const HloDataflowAnalysis& dataflow) const { 129 VLOG(4) << "UseIsBeforeValueDefinition(use=" << use 130 << ", value=" << value.ToShortString() << ")"; 131 if (ExecutesBefore(use.instruction, value.defining_instruction())) { 132 VLOG(4) << " use instruction executes before value-defining instruction"; 133 return true; 134 } 135 136 // If the use is at the instruction where the value is defined, then the use 137 // is before the def if the instruction allows buffer sharing (in place 138 // computation). 139 if (use.instruction == value.defining_instruction() && 140 CanShareOperandBufferWithUser( 141 use.instruction->mutable_operand(use.operand_number), 142 use.operand_index, value.defining_instruction(), 143 value.defining_index(), dataflow)) { 144 VLOG(4) << " use is value def, and instruction can share use buffer"; 145 return true; 146 } 147 148 // The use at a while is an input to a phi, and logically occurs before values 149 // are defined in the body or condition computations. 150 if (use.instruction->opcode() == HloOpcode::kWhile) { 151 const HloInstruction* xla_while = use.instruction; 152 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), 153 xla_while->while_body()) || 154 call_graph_->InstructionIsNestedIn(value.defining_instruction(), 155 xla_while->while_condition())) { 156 VLOG(4) << " use is while " << use.instruction->name() 157 << " and def is in condition or body"; 158 return true; 159 } 160 } 161 162 // Similarly if the value is defined at a while, it logically occurs after any 163 // uses in the body or condition computations. 164 if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { 165 CHECK(value.is_phi()); 166 const HloInstruction* xla_while = value.defining_instruction(); 167 if (call_graph_->InstructionIsNestedIn(use.instruction, 168 xla_while->while_body()) || 169 call_graph_->InstructionIsNestedIn(use.instruction, 170 xla_while->while_condition())) { 171 VLOG(4) << " value is while " << value.defining_instruction()->name() 172 << " and use is in condition or body"; 173 return true; 174 } 175 } 176 177 // The use at a call occurs before values that are defined in the called 178 // computation. 179 if (use.instruction->opcode() == HloOpcode::kCall) { 180 const HloInstruction* call = use.instruction; 181 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), 182 call->to_apply())) { 183 VLOG(4) << " use is call " << use.instruction->name() 184 << " and def is in called computation"; 185 return true; 186 } 187 } 188 189 if (use.instruction->opcode() == HloOpcode::kConditional) { 190 const HloInstruction* conditional = use.instruction; 191 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), 192 conditional->true_computation())) { 193 VLOG(4) << " use is conditional " << use.instruction->name() 194 << " and def is in TRUE computation"; 195 return true; 196 } 197 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), 198 conditional->false_computation())) { 199 VLOG(4) << " use is conditional " << use.instruction->name() 200 << " and def is in FALSE computation"; 201 return true; 202 } 203 } 204 205 VLOG(4) << " use is not before value"; 206 return false; 207 } 208 209 bool HloOrdering::LiveRangeStrictlyBefore( 210 const HloValue& a, const HloValue& b, 211 const HloDataflowAnalysis& dataflow) const { 212 VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() 213 << ", b = " << b.ToShortString() << ")"; 214 if (!IsDefinedBefore(a, b)) { 215 VLOG(4) << "a not defined before b"; 216 return false; 217 } 218 219 // All uses of 'a' must be before 'b' is defined. 220 for (const HloUse& use : a.uses()) { 221 if (!UseIsBeforeValueDefinition(use, b, dataflow)) { 222 VLOG(4) << "use of a (" << use << ") not before b is defined"; 223 return false; 224 } 225 } 226 227 return true; 228 } 229 230 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, 231 const HloDataflowAnalysis& dataflow) const { 232 // Buffers without disjoint liveness may interfere. 233 return !LiveRangeStrictlyBefore(a, b, dataflow) && 234 !LiveRangeStrictlyBefore(b, a, dataflow); 235 } 236 237 HloOrderingProto HloOrdering::ToProto() const { 238 HloOrderingProto proto; 239 for (const auto& computation : module_->computations()) { 240 const std::vector<const HloInstruction*>* sequence = 241 SequentialOrder(*computation); 242 if (sequence != nullptr) { 243 HloOrderingProto::SequentialComputation* proto_computation = 244 proto.add_sequential_computations(); 245 proto_computation->set_computation_name(computation->name()); 246 for (const HloInstruction* instruction : *sequence) { 247 *proto_computation->add_instruction_names() = instruction->name(); 248 } 249 } 250 } 251 return proto; 252 } 253 254 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) 255 : HloOrdering(module) {} 256 257 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( 258 const HloInstruction* a, const HloInstruction* b) const { 259 CHECK_EQ(a->parent(), b->parent()); 260 261 // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. 262 return a != b && predecessors_.at(a->parent())->IsReachable(a, b); 263 } 264 265 string PredecessorHloOrdering::ToStringHelper(const string& name) const { 266 std::vector<string> pieces; 267 pieces.push_back(name); 268 for (auto* computation : module_->MakeNonfusionComputations()) { 269 pieces.push_back(tensorflow::strings::Printf("computation %s:", 270 computation->name().c_str())); 271 const auto all = computation->MakeInstructionPostOrder(); 272 for (auto instruction : all) { 273 pieces.push_back(tensorflow::strings::Printf( 274 " %s predecessors:", instruction->name().c_str())); 275 for (auto predecessor : all) { 276 if (predecessors_.at(computation) 277 ->IsReachable(predecessor, instruction)) { 278 pieces.push_back( 279 tensorflow::strings::Printf(" %s", predecessor->name().c_str())); 280 } 281 } 282 } 283 } 284 return tensorflow::str_util::Join(pieces, "\n"); 285 } 286 287 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) 288 : PredecessorHloOrdering(module) { 289 // Compute predecessor relationships between all instructions to determine 290 // ordering based on dependencies. ExecutesBefore will return true iff there 291 // exists a path in the HLO computation graph from 'a' to 'b'. 292 for (auto* computation : module->MakeNonfusionComputations()) { 293 predecessors_.emplace(computation, computation->ComputeReachability()); 294 } 295 } 296 297 string DependencyHloOrdering::ToString() const { 298 return ToStringHelper("DependencyHloOrdering"); 299 } 300 301 SequentialHloOrdering::SequentialHloOrdering( 302 const HloModule* module, const HloModuleSequence& module_sequence) 303 : HloOrdering(module), module_sequence_(module_sequence) { 304 // Create a map from instruction to its order position. 305 for (auto computation_order : module_sequence_) { 306 const std::vector<const HloInstruction*>& order = computation_order.second; 307 for (int i = 0; i < order.size(); ++i) { 308 DCHECK_EQ(0, order_position_.count(order[i])); 309 order_position_.emplace(order[i], i); 310 } 311 } 312 } 313 314 bool SequentialHloOrdering::ExecutesBeforeInSameComputation( 315 const HloInstruction* a, const HloInstruction* b) const { 316 CHECK_EQ(a->parent(), b->parent()); 317 // If either instruction is not in the order, then 'a' and 'b' are unordered. 318 if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { 319 return false; 320 } 321 return order_position_.at(a) < order_position_.at(b); 322 } 323 324 const std::vector<const HloInstruction*>* 325 SequentialHloOrdering::SequentialOrder( 326 const HloComputation& computation) const { 327 auto find_it = module_sequence_.find(&computation); 328 return find_it == module_sequence_.end() ? nullptr : &find_it->second; 329 } 330 331 string SequentialHloOrdering::ToString() const { 332 std::vector<string> pieces; 333 pieces.push_back("SequentialHloOrdering"); 334 for (auto* computation : module_->computations()) { 335 pieces.push_back(tensorflow::strings::Printf("computation %s order:", 336 computation->name().c_str())); 337 // Gather all instructions in the module sequence for this computation and 338 // sort them by their position. 339 std::vector<const HloInstruction*> instructions; 340 for (auto& instruction_position : order_position_) { 341 const HloInstruction* instruction = instruction_position.first; 342 if (instruction->parent() == computation) { 343 instructions.push_back(instruction); 344 } 345 } 346 std::sort(instructions.begin(), instructions.end(), 347 [this](const HloInstruction* a, const HloInstruction* b) { 348 return order_position_.at(a) < order_position_.at(b); 349 }); 350 for (auto instruction : instructions) { 351 pieces.push_back( 352 tensorflow::strings::Printf(" %s", instruction->name().c_str())); 353 } 354 } 355 return tensorflow::str_util::Join(pieces, "\n"); 356 } 357 358 std::ostream& operator<<( 359 std::ostream& out, 360 const SequentialHloOrdering::HloModuleSequence& module_sequence) { 361 for (auto computation_pair : module_sequence) { 362 const HloComputation* computation = computation_pair.first; 363 const std::vector<const HloInstruction*>& computation_sequence = 364 computation_pair.second; 365 out << "Computation " << computation->name() << ":\n"; 366 for (auto* instruction : computation_sequence) { 367 out << " " << instruction->name() << "\n"; 368 } 369 } 370 return out; 371 } 372 373 } // namespace xla 374