1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_computation.h" 17 18 #include <stddef.h> 19 #include <algorithm> 20 #include <functional> 21 #include <list> 22 #include <queue> 23 #include <set> 24 #include <sstream> 25 26 #include "tensorflow/compiler/xla/layout_util.h" 27 #include "tensorflow/compiler/xla/map_util.h" 28 #include "tensorflow/compiler/xla/ptr_util.h" 29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 30 #include "tensorflow/compiler/xla/service/hlo_module.h" 31 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 32 #include "tensorflow/compiler/xla/shape_util.h" 33 #include "tensorflow/compiler/xla/status_macros.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/compiler/xla/util.h" 36 #include "tensorflow/core/lib/core/errors.h" 37 #include "tensorflow/core/lib/core/status.h" 38 #include "tensorflow/core/lib/gtl/flatset.h" 39 #include "tensorflow/core/lib/strings/str_util.h" 40 #include "tensorflow/core/lib/strings/strcat.h" 41 #include "tensorflow/core/platform/logging.h" 42 43 namespace xla { 44 45 using ::tensorflow::strings::StrCat; 46 47 std::unique_ptr<HloComputation> HloComputation::Builder::Build( 48 HloInstruction* root_instruction) { 49 int parameter_count = 0; 50 for (auto& instruction : instructions_) { 51 if (instruction->opcode() == HloOpcode::kParameter) { 52 parameter_count++; 53 } 54 } 55 // If root_instruction is not specified use the last added instruction. 56 HloInstruction* root = 57 root_instruction ? root_instruction : last_added_instruction_; 58 CHECK_NE(nullptr, root); 59 return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, 60 root, fusion_instruction_)); 61 } 62 63 HloComputation::HloComputation( 64 const string& name, int parameter_count, 65 std::vector<std::unique_ptr<HloInstruction>>* instructions, 66 HloInstruction* root_instruction, HloInstruction* fusion_instruction) 67 : name_(name), 68 root_instruction_(root_instruction), 69 fusion_instruction_(fusion_instruction) { 70 param_instructions_.resize(parameter_count, nullptr); 71 bool root_found = false; 72 for (auto& instruction : *instructions) { 73 if (instruction->opcode() == HloOpcode::kParameter) { 74 int64 param_no = instruction->parameter_number(); 75 CHECK(param_no >= 0 && param_no < parameter_count) 76 << "\nERROR: invalid parameter number. Expected [0, " 77 << parameter_count << "), got " << param_no; 78 CHECK(param_instructions_[param_no] == nullptr) 79 << "\nERROR: parameter number " << param_no 80 << " already allocated in this computation"; 81 param_instructions_[param_no] = instruction.get(); 82 } 83 root_found |= instruction.get() == root_instruction_; 84 AddInstructionInternal(std::move(instruction)); 85 } 86 CHECK(root_found) 87 << "\nERROR: root instruction is not present in computation."; 88 } 89 90 HloInstruction* HloComputation::AddInstruction( 91 std::unique_ptr<HloInstruction> instruction) { 92 CHECK(instruction->opcode() != HloOpcode::kParameter) 93 << "Parameter instructions cannot be added to a computation after " 94 << "it has been built"; 95 return AddInstructionInternal(std::move(instruction)); 96 } 97 98 HloInstruction* HloComputation::AddInstructionInternal( 99 std::unique_ptr<HloInstruction> instruction) { 100 if (parent() != nullptr) { 101 instruction->UniquifyName(&parent()->instruction_name_uniquer()); 102 instruction->SetUniqueId(parent()->NewUniqueInstructionId()); 103 } 104 Reparent(instruction.get()); 105 HloInstruction* pinst = instruction.get(); 106 instruction_iterators_[pinst] = 107 instructions_.insert(instructions_.end(), std::move(instruction)); 108 return pinst; 109 } 110 111 HloInstruction* HloComputation::AddParameter( 112 std::unique_ptr<HloInstruction> instruction) { 113 CHECK(instruction->opcode() == HloOpcode::kParameter); 114 CHECK(IsFusionComputation()); 115 CHECK(fusion_instruction_->operand_count() == param_instructions_.size()); 116 instruction->set_parent(this); 117 param_instructions_.push_back(instruction.get()); 118 AddInstructionInternal(std::move(instruction)); 119 return instructions_.back().get(); 120 } 121 122 Status HloComputation::RemoveParameter(int64 param_no) { 123 CHECK_GE(param_no, 0); 124 CHECK_LT(param_no, param_instructions_.size()); 125 CHECK(IsFusionComputation()); 126 HloInstruction* param_instruction = param_instructions_[param_no]; 127 auto param_instruction_iterator = param_instructions_.begin() + param_no; 128 param_instructions_.erase(param_instruction_iterator); 129 // Throw removed fused parameter instruction away. 130 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 131 132 while (param_no < param_instructions_.size()) { 133 param_instruction = param_instructions_[param_no]; 134 string param_name = param_instruction->name(); 135 // Fusion parameters are named foo.param_1, bar.param_2, etc. We are 136 // renumbering the parameters, so replace the final number in the name with 137 // the updated value. 138 const string param_underscore = ".param_"; 139 size_t index = param_name.rfind(param_underscore); 140 if (index == string::npos) { 141 string after_param = name().substr(index + param_underscore.size()); 142 int64 numeric_suffix; 143 if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { 144 param_name = 145 StrCat(param_name.substr(0, index), param_underscore, param_no); 146 } 147 } 148 149 HloInstruction* new_instr = 150 AddInstructionInternal(HloInstruction::CreateParameter( 151 param_no, param_instruction->shape(), param_name)); 152 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); 153 param_instructions_[param_no] = new_instr; 154 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 155 param_no++; 156 } 157 158 return Status::OK(); 159 } 160 161 void HloComputation::Reparent(HloInstruction* instruction) { 162 instruction->set_parent(this); 163 } 164 165 bool HloComputation::IsRemovable(const HloInstruction* instruction) { 166 // If the instruction has control predecessors or successors then we cannot 167 // remove the instruction without violating ordering constraints (added, for 168 // example, to avert interference due to buffer aliasing). 169 if (!instruction->control_predecessors().empty() || 170 !instruction->control_successors().empty()) { 171 return false; 172 } 173 174 if (instruction->opcode() == HloOpcode::kParameter && 175 !IsFusionComputation()) { 176 return false; 177 } 178 179 return true; 180 } 181 182 bool HloComputation::HasSideEffect() const { 183 for (auto* instruction : instructions()) { 184 if (instruction->HasSideEffect()) { 185 return true; 186 } 187 } 188 return false; 189 } 190 191 Status HloComputation::RemoveInstructionAndUnusedOperands( 192 HloInstruction* instruction) { 193 TF_RET_CHECK(root_instruction() != instruction); 194 195 TF_RET_CHECK(instruction->user_count() == 0); 196 TF_RET_CHECK(IsRemovable(instruction)) 197 << "Cannot remove instruction: " << instruction->ToString(); 198 std::unordered_set<HloInstruction*> removed; 199 std::queue<HloInstruction*> worklist; 200 worklist.push(instruction); 201 while (!worklist.empty()) { 202 HloInstruction* item = worklist.front(); 203 worklist.pop(); 204 205 if (removed.count(item) != 0 || item->user_count() != 0 || 206 item == root_instruction() || !IsRemovable(item) || 207 item->HasSideEffect()) { 208 continue; 209 } 210 for (int i = 0; i < item->operand_count(); ++i) { 211 worklist.push(item->mutable_operand(i)); 212 } 213 214 TF_RETURN_IF_ERROR(RemoveInstruction(item)); 215 removed.insert(item); 216 } 217 return Status::OK(); 218 } 219 220 Status HloComputation::RemoveInstruction(HloInstruction* instruction) { 221 VLOG(2) << "Removing instruction " << instruction->name() 222 << " from computation " << name(); 223 TF_RET_CHECK(IsRemovable(instruction)) 224 << "cannot remove instruction: " << instruction->ToString(); 225 TF_RET_CHECK(root_instruction() != instruction) 226 << "cannot remove root instruction " << instruction->name(); 227 TF_RET_CHECK(instruction->user_count() == 0) 228 << "instruction " << instruction->name() 229 << " has users and cannot be removed"; 230 TF_RET_CHECK(instruction->control_predecessors().empty()) 231 << "instruction " << instruction->name() 232 << " has control predecessors and cannot be removed"; 233 TF_RET_CHECK(instruction->control_successors().empty()) 234 << "instruction " << instruction->name() 235 << " has control successors and cannot be removed"; 236 237 TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); 238 auto inst_it = instruction_iterators_.at(instruction); 239 (*inst_it)->set_parent(nullptr); 240 instruction->DetachFromOperands(); 241 instructions_.erase(inst_it); 242 return Status::OK(); 243 } 244 245 void HloComputation::set_root_instruction( 246 HloInstruction* new_root_instruction) { 247 // The shape of the root (ignoring layout) is an invariant of the computation 248 // for non-fusion cases. 249 if (!IsFusionComputation()) { 250 CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), 251 root_instruction_->shape())) 252 << new_root_instruction->shape().ShortDebugString() 253 << " is incompatible with " 254 << root_instruction_->shape().ShortDebugString(); 255 } 256 bool root_found = false; 257 for (auto& instruction : instructions_) { 258 if (new_root_instruction == instruction.get()) { 259 root_found = true; 260 break; 261 } 262 } 263 DCHECK(root_found); 264 265 root_instruction_ = new_root_instruction; 266 } 267 268 namespace { 269 270 // Helper class which computes the post order of an expression rooted at a 271 // particular instruction. 272 class InstructionPostOrderer : public DfsHloVisitorWithDefault { 273 public: 274 // added_instructions is the set of instructions which have already been 275 // accounted for in the post order in previous invocations of 276 // GetOrder. Without this mechanism, instructions which are predecessors of 277 // multiple root instructions of the computation can be added to the post 278 // order more than once. 279 static std::list<HloInstruction*> GetOrder( 280 HloInstruction* root, 281 tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) { 282 InstructionPostOrderer orderer(added_instructions); 283 TF_CHECK_OK(root->Accept(&orderer)); 284 return std::move(orderer.post_order_); 285 } 286 287 private: 288 explicit InstructionPostOrderer( 289 tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) 290 : added_instructions_(added_instructions) {} 291 ~InstructionPostOrderer() override {} 292 293 Status DefaultAction(HloInstruction* hlo_instruction) override { 294 if (added_instructions_->count(hlo_instruction) == 0) { 295 post_order_.push_back(hlo_instruction); 296 added_instructions_->insert(hlo_instruction); 297 } 298 return Status::OK(); 299 } 300 301 std::list<HloInstruction*> post_order_; 302 tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_; 303 }; 304 305 // Helper which builds a post order of the HLO call graph. 306 void ComputeComputationPostOrder( 307 HloComputation* computation, 308 tensorflow::gtl::FlatSet<HloComputation*>* visited, 309 std::list<HloComputation*>* post_order) { 310 if (visited->count(computation) > 0) { 311 return; 312 } 313 314 for (auto* instruction : computation->instructions()) { 315 for (HloComputation* called_computation : 316 instruction->called_computations()) { 317 ComputeComputationPostOrder(called_computation, visited, post_order); 318 } 319 } 320 321 visited->insert(computation); 322 post_order->push_back(computation); 323 } 324 325 } // namespace 326 327 std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { 328 std::list<HloInstruction*> post_order; 329 std::list<HloInstruction*> trace_instructions; 330 tensorflow::gtl::FlatSet<HloInstruction*> added_instructions; 331 for (auto& instruction : instructions_) { 332 if (instruction->opcode() == HloOpcode::kTrace) { 333 // Trace instructions aren't handled by the DFS visitor. Add trace 334 // instructions to the post order at the end (necessarily they have no 335 // users). 336 trace_instructions.push_back(instruction.get()); 337 } else if (instruction->users().empty()) { 338 post_order.splice(post_order.end(), 339 InstructionPostOrderer::GetOrder(instruction.get(), 340 &added_instructions)); 341 } 342 } 343 post_order.splice(post_order.end(), trace_instructions); 344 CHECK_EQ(instructions_.size(), post_order.size()) 345 << "number of instructions does not match post order size"; 346 return post_order; 347 } 348 349 std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList() 350 const { 351 tensorflow::gtl::FlatSet<HloComputation*> visited; 352 std::list<HloComputation*> post_order; 353 354 // To avoid special handling of this computation, cast away const of 355 // 'this'. 'this' is immediately removed from the post order after 356 // construction. 357 ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited, 358 &post_order); 359 360 // We don't want to include this computation in the post order. 361 CHECK_EQ(this, post_order.back()); 362 post_order.pop_back(); 363 364 return post_order; 365 } 366 367 string HloComputation::ToString(const HloPrintOptions& options) const { 368 std::ostringstream s; 369 for (int i = 0; i < options.indent_amount(); i++) { 370 s << " "; 371 } 372 if (options.print_percent()) { 373 s << "%"; 374 } 375 s << name(); 376 if (options.print_program_shape()) { 377 s << " " << ShapeUtil::HumanString(ComputeProgramShape()); 378 } 379 s << " {\n"; 380 for (const HloInstruction* instruction : MakeInstructionPostOrder()) { 381 for (int i = 0; i < options.indent_amount(); i++) { 382 s << " "; 383 } 384 s << " " << (instruction == root_instruction_ ? "ROOT " : "") 385 << instruction->ToString(options) << "\n"; 386 } 387 for (int i = 0; i < options.indent_amount(); i++) { 388 s << " "; 389 } 390 s << "}"; 391 return s.str(); 392 } 393 394 HloComputationProto HloComputation::ToProto() const { 395 HloComputationProto proto; 396 proto.set_name(name_); 397 for (const HloInstruction* instruction : MakeInstructionPostOrder()) { 398 HloInstructionProto instruction_proto = instruction->ToProto(); 399 proto.add_instructions()->Swap(&instruction_proto); 400 } 401 proto.set_root_name(root_instruction()->name()); 402 return proto; 403 } 404 405 /* static */ StatusOr<std::unique_ptr<HloComputation>> 406 HloComputation::CreateFromProto( 407 HloModule* module, const HloComputationProto& proto, 408 const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, 409 const std::function<void(std::unique_ptr<HloComputation>)>& 410 add_fused_computation, 411 HloInstruction* fusion_instruction) { 412 std::vector<std::unique_ptr<HloInstruction>> instructions; 413 tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map; 414 int64 parameter_count = 0; 415 for (const HloInstructionProto& instruction_proto : proto.instructions()) { 416 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction, 417 HloInstruction::CreateFromProto( 418 module, instruction_proto, instruction_map, 419 computation_map, add_fused_computation)); 420 if (instruction->opcode() == HloOpcode::kParameter) { 421 parameter_count++; 422 } 423 TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); 424 instruction_map[instruction->name()] = instruction.get(); 425 instructions.push_back(std::move(instruction)); 426 } 427 428 TF_RET_CHECK(!proto.root_name().empty()); 429 TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); 430 HloInstruction* root = instruction_map.at(proto.root_name()); 431 return WrapUnique(new HloComputation( 432 proto.name(), parameter_count, &instructions, root, fusion_instruction)); 433 } 434 435 void HloComputation::FuseInstructionsInto( 436 tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse, 437 HloInstruction* fusion_instruction) { 438 CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); 439 HloInstruction* root = instructions_to_fuse.front(); 440 TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction)); 441 if (root == root_instruction()) { 442 set_root_instruction(fusion_instruction); 443 } 444 TF_CHECK_OK(RemoveInstruction(root)); 445 for (size_t i = 1; i < instructions_to_fuse.size(); ++i) { 446 HloInstruction* instruction = instructions_to_fuse[i]; 447 fusion_instruction->FuseInstruction(instruction); 448 if (instruction->user_count() == 0) { 449 TF_CHECK_OK(RemoveInstruction(instruction)); 450 } 451 } 452 } 453 454 HloInstruction* HloComputation::CreateFusionInstruction( 455 tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse, 456 HloInstruction::FusionKind fusion_kind) { 457 HloInstruction* root = instructions_to_fuse.front(); 458 HloInstruction* fusion_instruction = AddInstruction( 459 HloInstruction::CreateFusion(root->shape(), fusion_kind, root)); 460 FuseInstructionsInto(instructions_to_fuse, fusion_instruction); 461 return fusion_instruction; 462 } 463 464 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper( 465 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, 466 ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) { 467 if (ShapeUtil::IsArray(instruction->shape())) { 468 if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { 469 // Use kCopy to copy array elements 470 HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( 471 instruction->shape(), HloOpcode::kCopy, instruction)); 472 if (copies_added != nullptr) { 473 *copies_added->mutable_element(*index) = copy; 474 } 475 return copy; 476 } else { 477 // Array elements which are not to be copied are passed through 478 // transparently. 479 return instruction; 480 } 481 } else if (ShapeUtil::IsTuple(instruction->shape())) { 482 std::vector<HloInstruction*> elements; 483 for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); 484 i++) { 485 HloInstruction* gte = 486 AddInstruction(HloInstruction::CreateGetTupleElement( 487 ShapeUtil::GetTupleElementShape(instruction->shape(), i), 488 instruction, i)); 489 490 index->push_back(i); 491 TF_ASSIGN_OR_RETURN( 492 HloInstruction * element, 493 DeepCopyHelper(gte, indices_to_copy, copies_added, index)); 494 elements.push_back(element); 495 index->pop_back(); 496 } 497 return AddInstruction(HloInstruction::CreateTuple(elements)); 498 } else { 499 return FailedPrecondition( 500 "Can only copy array and tuple shaped instructions"); 501 } 502 } 503 504 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( 505 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, 506 ShapeTree<HloInstruction*>* copies_added) { 507 if (instruction->parent() != this) { 508 return FailedPrecondition( 509 "Can't deep copy instruction %s: instruction is not in computation %s", 510 instruction->name().c_str(), name().c_str()); 511 } 512 if (indices_to_copy != nullptr && 513 !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { 514 return FailedPrecondition( 515 "Can't deep copy instruction %s: given shape tree of indices to copy " 516 "has incompatible shapes: %s vs. %s", 517 instruction->name().c_str(), 518 ShapeUtil::HumanString(instruction->shape()).c_str(), 519 ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); 520 } 521 522 ShapeIndex index; 523 return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index); 524 } 525 526 ProgramShape HloComputation::ComputeProgramShape() const { 527 ProgramShape program_shape; 528 529 for (auto* param_instruction : param_instructions_) { 530 *program_shape.add_parameters() = param_instruction->shape(); 531 *program_shape.add_parameter_names() = param_instruction->name(); 532 } 533 *program_shape.mutable_result() = root_instruction_->shape(); 534 535 LayoutUtil::ClearLayout(&program_shape); 536 return program_shape; 537 } 538 539 bool HloComputation::operator==(const HloComputation& other) const { 540 std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited; 541 std::function<bool(const HloInstruction*, const HloInstruction*)> eq = 542 [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { 543 // If <a,b> are visited but not identical, the recursion should have 544 // been aborted. So, if <a,b> are visited at this point, they must be 545 // identical. 546 if (visited.count(std::make_pair(a, b)) > 0) { 547 return true; 548 } 549 visited.emplace(a, b); 550 return a->Identical( 551 *b, eq, [](const HloComputation* a, const HloComputation* b) { 552 return *a == *b; 553 }); 554 }; 555 return eq(root_instruction(), other.root_instruction()); 556 } 557 558 Status HloComputation::ReplaceWithNewInstruction( 559 HloInstruction* old_instruction, 560 std::unique_ptr<HloInstruction> new_instruction) { 561 return ReplaceInstruction(old_instruction, 562 AddInstruction(std::move(new_instruction))); 563 } 564 565 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, 566 HloInstruction* new_instruction) { 567 TF_RET_CHECK( 568 ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) 569 << ShapeUtil::HumanString(old_instruction->shape()) << " vs " 570 << ShapeUtil::HumanString(new_instruction->shape()); 571 572 VLOG(10) << "transformed " << old_instruction->ToString() << " to " 573 << new_instruction->ToString(); 574 // Try to add metadata for HLO instructions that are created to replace 575 // existing HLO instructions (e.g. during optimizations). The assumption is 576 // that the old instruction and the new instruction would perform the same 577 // function, and that they would be correlated to the same TF op. This might 578 // not always be correct since HLO optimizations can cross TF op boundaries. 579 // But still this seems to be better than nothing. 580 if (new_instruction->metadata().op_name().empty()) { 581 new_instruction->set_metadata(old_instruction->metadata()); 582 } 583 TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction)); 584 return RemoveInstructionAndUnusedOperands(old_instruction); 585 } 586 587 std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() 588 const { 589 const std::list<HloInstruction*> all = MakeInstructionPostOrder(); 590 auto result = MakeUnique<HloReachabilityMap>(all); 591 592 std::vector<HloInstruction*> inputs; 593 for (const HloInstruction* hlo : all) { 594 inputs.assign(hlo->operands().begin(), hlo->operands().end()); 595 inputs.insert(inputs.end(), hlo->control_predecessors().begin(), 596 hlo->control_predecessors().end()); 597 result->SetReachabilityToUnion(inputs, hlo); 598 } 599 return result; 600 } 601 602 void HloComputation::UpdateReachabilityThroughInstruction( 603 const HloInstruction* instruction, HloReachabilityMap* reachability_map) { 604 std::queue<const HloInstruction*> worklist; 605 worklist.push(instruction); 606 607 std::vector<HloInstruction*> inputs; 608 609 while (!worklist.empty()) { 610 const HloInstruction* item = worklist.front(); 611 worklist.pop(); 612 613 inputs.assign(item->operands().begin(), item->operands().end()); 614 inputs.insert(inputs.end(), item->control_predecessors().begin(), 615 item->control_predecessors().end()); 616 617 if (reachability_map->SetReachabilityToUnion(inputs, item)) { 618 // Add immediate successors to worklist. 619 for (const HloInstruction* user : item->users()) { 620 worklist.push(user); 621 } 622 for (const HloInstruction* succ : item->control_successors()) { 623 worklist.push(succ); 624 } 625 } 626 } 627 } 628 629 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { 630 std::vector<HloInstruction*> unreachable_roots; 631 for (auto* instruction : instructions()) { 632 if (instruction->user_count() == 0 && 633 instruction->control_successors().empty() && 634 instruction != root_instruction()) { 635 unreachable_roots.push_back(instruction); 636 } 637 } 638 VLOG(3) << "Unreachable roots:" 639 << tensorflow::str_util::Join( 640 unreachable_roots, "\n\t", 641 [](string* out, const HloInstruction* hlo) { 642 tensorflow::strings::StrAppend(out, hlo->ToString()); 643 }); 644 return unreachable_roots; 645 } 646 647 template <typename HloInstructionPtr> 648 Status HloComputation::Accept( 649 DfsHloVisitorBase<HloInstructionPtr>* visitor) const { 650 // Visit unreachable roots. Beware that the visitor might delete the currently 651 // visited root, which would invalidate iterators if the unreachable roots 652 // weren't computed ahead of time. 653 for (HloInstruction* root : CollectUnreachableRoots()) { 654 VLOG(3) << "Traversing unreachable root: " << root->ToString(); 655 // Call FinishVisit only at the end. 656 TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); 657 } 658 // Visit the computation root instruction last. 659 return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); 660 } 661 662 // Explicit instantiations. 663 template Status HloComputation::Accept(DfsHloVisitor* visitor) const; 664 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; 665 666 Status HloComputation::AcceptWithOperandOrder( 667 DfsHloVisitor* visitor, 668 const HloInstruction::CompareFunction& operand_order) const { 669 // Visit unreachable roots. Beware that the visitor might delete the currently 670 // visited root, which would invalidate iterators if the unreachable roots 671 // weren't computed ahead of time. 672 for (HloInstruction* root : CollectUnreachableRoots()) { 673 TF_RETURN_IF_ERROR( 674 root->AcceptWithOperandOrder(visitor, operand_order, 675 /*call_finish_visit=*/false)); 676 } 677 // Visit the computation root instruction last. 678 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order, 679 /*call_finish_visit=*/true); 680 } 681 682 template <typename HloInstructionPtr> 683 Status HloComputation::AcceptOrdered( 684 DfsHloVisitorBase<HloInstructionPtr>* visitor, 685 const std::vector<const HloInstruction*>& order) const { 686 VLOG(3) << "Accepting visitor with order."; 687 for (HloInstruction* root : CollectUnreachableRoots()) { 688 TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) 689 << root->ToString(); 690 } 691 TF_RET_CHECK(order.size() == instruction_count()); 692 std::unordered_set<const HloInstruction*> visited; 693 for (const HloInstruction* instruction : order) { 694 VLOG(3) << "Visiting ordered: " << instruction->ToString(); 695 TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) 696 << "Instruction " << instruction->name() << " is not in computation " 697 << name(); 698 TF_RET_CHECK(visited.count(instruction) == 0) 699 << "Instruction " << instruction->name() 700 << " appears more than once in order"; 701 HloInstruction* mutable_instruction = 702 const_cast<HloInstruction*>(instruction); 703 TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction)); 704 TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor)); 705 visitor->SetVisited(*mutable_instruction); 706 TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction)); 707 visited.insert(instruction); 708 } 709 TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction())); 710 return Status::OK(); 711 } 712 713 // Explicit instantiations. 714 template Status HloComputation::AcceptOrdered( 715 DfsHloVisitor*, const std::vector<const HloInstruction*>&) const; 716 template Status HloComputation::AcceptOrdered( 717 ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const; 718 719 Status HloComputation::Accept( 720 const std::function<Status(HloInstruction*)>& visitor_func) { 721 FunctionVisitor visitor(visitor_func); 722 return this->Accept(&visitor); 723 } 724 725 Status HloComputation::Accept( 726 const std::function<Status(const HloInstruction*)>& visitor_func) const { 727 ConstFunctionVisitor visitor(visitor_func); 728 return this->Accept(&visitor); 729 } 730 731 std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix, 732 HloModule* module) { 733 return CloneWithReplacements( 734 /*replacements=*/std::unordered_map<const HloInstruction*, 735 std::unique_ptr<HloInstruction>>(), 736 module, suffix); 737 } 738 739 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( 740 std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 741 replacements, 742 HloModule* module, const string& suffix) { 743 // Look up instr in the replacements map, and return either the replacement, 744 // or instr, if the replacement isn't present. 745 // 746 // Note: This can return null, indicating that instr should not be present in 747 // the new computation. 748 auto replace = [&](HloInstruction* instr) { 749 auto it = replacements.find(instr); 750 if (it == replacements.end()) { 751 return instr; 752 } 753 return it->second.get(); 754 }; 755 756 VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; 757 std::vector<HloInstruction*> postorder; 758 for (HloInstruction* instr : MakeInstructionPostOrder()) { 759 if (HloInstruction* replacement = replace(instr)) { 760 postorder.push_back(replacement); 761 } 762 } 763 764 std::unordered_map<HloInstruction*, HloInstruction*> clone_map; 765 std::vector<std::unique_ptr<HloInstruction>> instructions; 766 std::unique_ptr<HloInstruction> new_instr = nullptr; 767 for (auto instr : postorder) { 768 std::vector<HloInstruction*> new_operands; 769 for (auto operand : instr->operands()) { 770 auto replaced_operand = replace(operand); 771 // If replaced_operand is null, that means 'replacements' asked us not to 772 // include operand in the new computation. But we can't do that, because 773 // operand is used by instr. 774 CHECK_NE(replaced_operand, nullptr) 775 << "replacements map tried to eliminate a used instruction " 776 << operand->ToString() << ", used by " << instr->ToString(); 777 new_operands.push_back(FindOrDie(clone_map, replaced_operand)); 778 } 779 new_instr = 780 instr->CloneWithNewOperands(instr->shape(), new_operands, module); 781 InsertOrDie(&clone_map, instr, new_instr.get()); 782 instructions.push_back(std::move(new_instr)); 783 } 784 Builder builder(name() + "." + suffix); 785 for (auto& instr : instructions) { 786 builder.AddInstruction(std::move(instr)); 787 } 788 auto result = builder.Build( 789 /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); 790 791 // Clone control dependencies. 792 for (auto instr : postorder) { 793 HloInstruction* new_instr = FindOrDie(clone_map, instr); 794 for (auto successor : instr->control_successors()) { 795 auto replaced_successor = replace(successor); 796 797 // successor may not be in clone_map, because it might have been 798 // removed by the replacements map. 799 if (replaced_successor == nullptr) { 800 continue; 801 } 802 803 TF_CHECK_OK(new_instr->AddControlDependencyTo( 804 FindOrDie(clone_map, replaced_successor))); 805 } 806 } 807 808 // We cloned the elements of 'replacements', so they're all going to be 809 // destroyed. HloInstructions need to be detached from their operands before 810 // they're destroyed, otherwise they stick around in the operands' users lists 811 // and cause use-after-frees. 812 for (auto& kv : replacements) { 813 if (std::unique_ptr<HloInstruction>& new_instr = kv.second) { 814 new_instr->DetachFromOperands(); 815 } 816 } 817 818 return result; 819 } 820 821 void HloComputation::UniquifyName(NameUniquer* name_uniquer) { 822 name_ = name_uniquer->GetUniqueName(name_); 823 } 824 825 } // namespace xla 826