1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_computation.h" 17 18 #include <algorithm> 19 #include <cstddef> 20 #include <functional> 21 #include <list> 22 #include <queue> 23 #include <set> 24 #include <sstream> 25 26 #include "absl/algorithm/container.h" 27 #include "absl/container/flat_hash_map.h" 28 #include "absl/container/flat_hash_set.h" 29 #include "absl/memory/memory.h" 30 #include "absl/strings/numbers.h" 31 #include "absl/strings/str_cat.h" 32 #include "absl/strings/str_join.h" 33 #include "tensorflow/compiler/xla/layout_util.h" 34 #include "tensorflow/compiler/xla/map_util.h" 35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 36 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 37 #include "tensorflow/compiler/xla/service/hlo_module.h" 38 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 39 #include "tensorflow/compiler/xla/shape_util.h" 40 #include "tensorflow/compiler/xla/status_macros.h" 41 #include "tensorflow/compiler/xla/types.h" 42 #include "tensorflow/compiler/xla/util.h" 43 #include "tensorflow/core/lib/core/errors.h" 44 #include "tensorflow/core/lib/core/status.h" 45 #include "tensorflow/core/platform/logging.h" 46 47 namespace xla { 48 49 using absl::StrCat; 50 51 std::unique_ptr<HloComputation> HloComputation::Builder::Build( 52 HloInstruction* root_instruction) { 53 int parameter_count = 0; 54 for (auto& instruction : instructions_) { 55 if (instruction->opcode() == HloOpcode::kParameter) { 56 parameter_count++; 57 } 58 } 59 // If root_instruction is not specified use the last added instruction. 60 HloInstruction* root = 61 root_instruction ? root_instruction : last_added_instruction_; 62 CHECK_NE(nullptr, root); 63 return absl::WrapUnique(new HloComputation( 64 name_, parameter_count, &instructions_, root, fusion_instruction_)); 65 } 66 67 HloComputation::HloComputation( 68 const string& name, int parameter_count, 69 std::vector<std::unique_ptr<HloInstruction>>* instructions, 70 HloInstruction* root_instruction, HloInstruction* fusion_instruction) 71 : name_(NameUniquer::GetSanitizedName(name)), 72 unique_id_(-1), 73 root_instruction_(root_instruction), 74 fusion_instruction_(fusion_instruction) { 75 param_instructions_.resize(parameter_count, nullptr); 76 bool root_found = false; 77 for (auto& instruction : *instructions) { 78 if (instruction->opcode() == HloOpcode::kParameter) { 79 int64 param_no = instruction->parameter_number(); 80 CHECK(param_no >= 0 && param_no < parameter_count) 81 << "\nERROR: invalid parameter number. Expected [0, " 82 << parameter_count << "), got " << param_no; 83 CHECK(param_instructions_[param_no] == nullptr) 84 << "\nERROR: parameter number " << param_no 85 << " already allocated in this computation"; 86 param_instructions_[param_no] = instruction.get(); 87 } 88 root_found |= instruction.get() == root_instruction_; 89 AddInstructionInternal(std::move(instruction)); 90 } 91 CHECK(root_found) 92 << "\nERROR: root instruction is not present in computation."; 93 } 94 95 HloInstruction* HloComputation::AddInstruction( 96 std::unique_ptr<HloInstruction> instruction) { 97 CHECK(instruction->opcode() != HloOpcode::kParameter) 98 << "Parameter instructions cannot be added to a computation after " 99 << "it has been built"; 100 return AddInstructionInternal(std::move(instruction)); 101 } 102 103 HloInstruction* HloComputation::AddInstructionInternal( 104 std::unique_ptr<HloInstruction> instruction) { 105 if (parent() != nullptr) { 106 instruction->UniquifyName(&parent()->instruction_name_uniquer()); 107 instruction->SetUniqueId(parent()->NewUniqueInstructionId()); 108 } 109 instruction->set_parent(this); 110 HloInstruction* pinst = instruction.get(); 111 instruction_iterators_[pinst] = 112 instructions_.insert(instructions_.end(), std::move(instruction)); 113 return pinst; 114 } 115 116 HloInstruction* HloComputation::AddParameter( 117 std::unique_ptr<HloInstruction> instruction) { 118 CHECK(instruction->opcode() == HloOpcode::kParameter); 119 CHECK(IsFusionComputation()); 120 CHECK(fusion_instruction_->operand_count() == param_instructions_.size()); 121 instruction->set_parent(this); 122 param_instructions_.push_back(instruction.get()); 123 AddInstructionInternal(std::move(instruction)); 124 return instructions_.back().get(); 125 } 126 127 HloInstruction* HloComputation::AddEntryComputationParameter( 128 std::unique_ptr<HloInstruction> instruction) { 129 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); 130 CHECK_EQ(instruction->parameter_number(), num_parameters()); 131 CHECK(parent()->entry_computation() == this); 132 133 HloModuleConfig config = parent()->config(); 134 config.mutable_entry_computation_layout()->add_parameter_layout( 135 ShapeLayout(instruction->shape())); 136 parent()->set_config(config); 137 138 instruction->set_parent(this); 139 param_instructions_.push_back(instruction.get()); 140 AddInstructionInternal(std::move(instruction)); 141 142 return instructions_.back().get(); 143 } 144 145 Status HloComputation::RemoveParameter(int64 param_no) { 146 CHECK_GE(param_no, 0); 147 CHECK_LT(param_no, param_instructions_.size()); 148 CHECK(IsFusionComputation()); 149 HloInstruction* param_instruction = param_instructions_[param_no]; 150 auto param_instruction_iterator = param_instructions_.begin() + param_no; 151 param_instructions_.erase(param_instruction_iterator); 152 // Throw removed fused parameter instruction away. 153 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 154 155 while (param_no < param_instructions_.size()) { 156 param_instruction = param_instructions_[param_no]; 157 HloInstruction* new_instr = 158 AddInstructionInternal(HloInstruction::CreateParameter( 159 param_no, param_instruction->shape(), StrCat("param_", param_no))); 160 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); 161 param_instructions_[param_no] = new_instr; 162 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 163 param_no++; 164 } 165 166 return Status::OK(); 167 } 168 169 Status HloComputation::RemoveUnusedParameters() { 170 CHECK(IsFusionComputation()); 171 int64 removed = 0; 172 for (int64 i = 0; i < param_instructions_.size(); ++i) { 173 HloInstruction* param_instruction = param_instructions_[i]; 174 if (param_instruction->user_count() == 0 && 175 param_instruction != root_instruction()) { 176 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 177 ++removed; 178 continue; 179 } 180 181 if (removed > 0) { 182 const int64 param_no = i - removed; 183 HloInstruction* new_instr = AddInstructionInternal( 184 HloInstruction::CreateParameter(param_no, param_instruction->shape(), 185 StrCat("param_", param_no))); 186 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); 187 param_instructions_[param_no] = new_instr; 188 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); 189 } 190 } 191 param_instructions_.resize(param_instructions_.size() - removed); 192 return Status::OK(); 193 } 194 195 bool HloComputation::IsRemovable(const HloInstruction* instruction) { 196 // If the instruction has control predecessors or successors then we cannot 197 // remove the instruction without violating ordering constraints (added, for 198 // example, to avert interference due to buffer aliasing). 199 if (!instruction->control_predecessors().empty() || 200 !instruction->control_successors().empty()) { 201 return false; 202 } 203 204 if (instruction->opcode() == HloOpcode::kParameter && 205 !IsFusionComputation()) { 206 return false; 207 } 208 209 return true; 210 } 211 212 bool HloComputation::HasSideEffect() const { 213 for (auto* instruction : instructions()) { 214 if (instruction->HasSideEffect()) { 215 return true; 216 } 217 } 218 return false; 219 } 220 221 Status HloComputation::RemoveInstructionAndUnusedOperands( 222 HloInstruction* instruction) { 223 TF_RET_CHECK(root_instruction() != instruction); 224 225 TF_RET_CHECK(instruction->user_count() == 0); 226 TF_RET_CHECK(IsRemovable(instruction)) 227 << "Cannot remove instruction: " << instruction->ToString(); 228 absl::flat_hash_set<HloInstruction*> removed; 229 std::queue<HloInstruction*> worklist; 230 worklist.push(instruction); 231 while (!worklist.empty()) { 232 HloInstruction* item = worklist.front(); 233 worklist.pop(); 234 235 if (removed.contains(item) || item->user_count() != 0 || 236 item == root_instruction() || !IsRemovable(item) || 237 (item->HasSideEffect() && item != instruction)) { 238 continue; 239 } 240 for (int i = 0; i < item->operand_count(); ++i) { 241 worklist.push(item->mutable_operand(i)); 242 } 243 244 TF_RETURN_IF_ERROR(RemoveInstruction(item)); 245 removed.insert(item); 246 } 247 return Status::OK(); 248 } 249 250 Status HloComputation::RemoveInstruction(HloInstruction* instruction) { 251 VLOG(2) << "Removing instruction " << instruction->name() 252 << " from computation " << name(); 253 TF_RET_CHECK(IsRemovable(instruction)) 254 << "cannot remove instruction: " << instruction->ToString(); 255 TF_RET_CHECK(root_instruction() != instruction) 256 << "cannot remove root instruction " << instruction->name(); 257 TF_RET_CHECK(instruction->user_count() == 0) 258 << "instruction " << instruction->name() 259 << " has users and cannot be removed"; 260 TF_RET_CHECK(instruction->control_predecessors().empty()) 261 << "instruction " << instruction->name() 262 << " has control predecessors and cannot be removed"; 263 TF_RET_CHECK(instruction->control_successors().empty()) 264 << "instruction " << instruction->name() 265 << " has control successors and cannot be removed"; 266 267 auto inst_it = instruction_iterators_.find(instruction); 268 TF_RET_CHECK(inst_it != instruction_iterators_.end()); 269 (*inst_it->second)->set_parent(nullptr); 270 instructions_.erase(inst_it->second); 271 instruction_iterators_.erase(inst_it); 272 return Status::OK(); 273 } 274 275 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, 276 bool accept_different_shape) { 277 // The shape of the root (ignoring layout) is an invariant of the computation 278 // for non-fusion cases. 279 if (!IsFusionComputation() && !accept_different_shape) { 280 CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), 281 root_instruction_->shape())) 282 << new_root_instruction->shape() << " is incompatible with " 283 << root_instruction_->shape(); 284 } 285 bool root_found = false; 286 for (auto& instruction : instructions_) { 287 if (new_root_instruction == instruction.get()) { 288 root_found = true; 289 break; 290 } 291 } 292 DCHECK(root_found); 293 294 root_instruction_ = new_root_instruction; 295 } 296 297 namespace { 298 299 // Helper which builds a post order of the HLO call graph. 300 void ComputeComputationPostOrder(HloComputation* computation, 301 absl::flat_hash_set<HloComputation*>* visited, 302 std::vector<HloComputation*>* post_order) { 303 if (visited->insert(computation).second) { 304 for (auto* instruction : computation->instructions()) { 305 for (HloComputation* called_computation : 306 instruction->called_computations()) { 307 ComputeComputationPostOrder(called_computation, visited, post_order); 308 } 309 } 310 post_order->push_back(computation); 311 } 312 } 313 314 } // namespace 315 316 void HloComputation::ComputeInstructionPostOrder( 317 const HloComputation::ChannelDependencyGroup& channel_dependency_group, 318 std::vector<HloInstruction*>* post_order, HloInstruction* root, 319 absl::flat_hash_map<HloInstruction*, VisitState>* visited) const { 320 std::vector<HloInstruction*> dfs_stack; 321 dfs_stack.push_back(root); 322 while (!dfs_stack.empty()) { 323 const auto current = dfs_stack.back(); 324 auto it = visited->find(current); 325 if (it != visited->end()) { 326 if (it->second == kVisited) { 327 // Already visited. 328 dfs_stack.pop_back(); 329 continue; 330 } 331 // Visit this node. 332 CHECK_EQ(kVisiting, it->second); 333 dfs_stack.pop_back(); 334 post_order->push_back(current); 335 it->second = kVisited; 336 continue; 337 } 338 339 visited->insert({current, kVisiting}); 340 341 const auto get_channel_id = 342 [](HloInstruction* inst) -> absl::optional<int64> { 343 switch (inst->opcode()) { 344 case HloOpcode::kRecvDone: 345 return inst->channel_id(); 346 case HloOpcode::kAllReduce: 347 return inst->all_reduce_id(); 348 default: 349 return absl::nullopt; 350 } 351 }; 352 353 // When adding a predecessor to the dfs_stack, we need to also add its 354 // associated channel dependencies. 355 const auto add_dfs_stack = [&](HloInstruction* inst) { 356 auto channel_id = get_channel_id(inst); 357 if (channel_id && channel_dependency_group.count(*channel_id)) { 358 auto it = channel_dependency_group.find(*channel_id); 359 for (HloInstruction* cinst : it->second) { 360 dfs_stack.emplace_back(cinst); 361 } 362 } else { 363 dfs_stack.emplace_back(inst); 364 } 365 }; 366 367 const auto add_predecessors = [&](HloInstruction* inst) { 368 // Add the operands to the stack in reverse order so the first operand is 369 // processed first. This will produce a more natural ordering and a nicer 370 // result for things like HLO stringification. 371 const auto& operands = inst->operands(); 372 for (int64 i = operands.size() - 1; i >= 0; --i) { 373 add_dfs_stack(operands[i]); 374 } 375 376 for (HloInstruction* op : inst->control_predecessors()) { 377 add_dfs_stack(op); 378 } 379 }; 380 381 // If the current instruction is a channel instruction, add the dependencies 382 // from all associated instructions of the channel. 383 auto channel_id = get_channel_id(current); 384 if (channel_id && channel_dependency_group.count(*channel_id)) { 385 auto it = channel_dependency_group.find(*channel_id); 386 for (HloInstruction* cinst : it->second) { 387 add_predecessors(cinst); 388 } 389 } else { 390 add_predecessors(current); 391 } 392 } 393 } 394 395 HloComputation::ChannelDependencyGroup 396 HloComputation::ComputeChannelDependencies() const { 397 ChannelDependencyGroup channel_dependency_group; 398 for (const auto& instruction : instructions_) { 399 switch (instruction->opcode()) { 400 case HloOpcode::kSend: 401 case HloOpcode::kRecvDone: 402 channel_dependency_group[instruction->channel_id()].push_back( 403 instruction.get()); 404 break; 405 case HloOpcode::kAllReduce: { 406 auto all_reduce_id = instruction->all_reduce_id(); 407 if (all_reduce_id) { 408 channel_dependency_group[all_reduce_id.value()].push_back( 409 instruction.get()); 410 } 411 break; 412 } 413 default: 414 break; 415 } 416 } 417 return channel_dependency_group; 418 } 419 420 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { 421 auto channel_dependency_group = ComputeChannelDependencies(); 422 std::vector<HloInstruction*> post_order; 423 post_order.reserve(instruction_count()); 424 std::vector<HloInstruction*> trace_instructions; 425 absl::flat_hash_map<HloInstruction*, VisitState> visited; 426 visited.reserve(instruction_count()); 427 for (auto& instruction : instructions_) { 428 if (instruction->opcode() == HloOpcode::kTrace) { 429 // Trace instructions aren't handled by the DFS visitor. Add trace 430 // instructions to the post order at the end (necessarily they have no 431 // users). 432 trace_instructions.push_back(instruction.get()); 433 } else if (instruction->users().empty()) { 434 ComputeInstructionPostOrder(channel_dependency_group, &post_order, 435 instruction.get(), &visited); 436 } 437 } 438 post_order.insert(post_order.end(), trace_instructions.begin(), 439 trace_instructions.end()); 440 CHECK_EQ(instructions_.size(), post_order.size()) 441 << "number of instructions does not match post order size"; 442 return post_order; 443 } 444 445 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() 446 const { 447 absl::flat_hash_set<HloComputation*> visited; 448 std::vector<HloComputation*> post_order; 449 450 // To avoid special handling of this computation, cast away const of 451 // 'this'. 'this' is immediately removed from the post order after 452 // construction. 453 // 454 // TODO(b/78350259): This violates const-correctness, since while the original 455 // computation is not returned, we still retrieve non-const computations from 456 // a const one. Consider also avoiding const for HloComputation, or review XLA 457 // for const-correctness of non-HloInstruction* types like this. 458 ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited, 459 &post_order); 460 461 // We don't want to include this computation in the post order. 462 CHECK_EQ(this, post_order.back()); 463 post_order.pop_back(); 464 465 return post_order; 466 } 467 468 string HloComputation::ToString(const HloPrintOptions& options) const { 469 return ToString(options, MakeInstructionPostOrder()); 470 } 471 472 string HloComputation::ToString( 473 const HloPrintOptions& options, 474 absl::Span<const HloInstruction* const> instruction_order) const { 475 CHECK_EQ(instruction_order.size(), instruction_count()); 476 477 std::ostringstream s; 478 for (int i = 0; i < options.indent_amount(); i++) { 479 s << " "; 480 } 481 482 if (!options.is_in_nested_computation()) { 483 if (options.print_percent()) { 484 s << "%"; 485 } 486 s << name() << " "; 487 } 488 489 if (options.print_program_shape()) { 490 s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; 491 } 492 s << "{\n"; 493 { 494 // Print the instructions in this computation. 495 HloPrintOptions new_options = options; 496 new_options.set_indent_amount(options.indent_amount() + 1) 497 .set_is_in_nested_computation(true); 498 CanonicalNameMap name_map; 499 for (const HloInstruction* instruction : instruction_order) { 500 CHECK_EQ(this, instruction->parent()); 501 502 for (int i = 0; i < new_options.indent_amount(); i++) { 503 s << " "; 504 } 505 s << (instruction == root_instruction_ ? "ROOT " : "") 506 << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) 507 << "\n"; 508 } 509 } 510 511 for (int i = 0; i < options.indent_amount(); i++) { 512 s << " "; 513 } 514 s << "}"; 515 return s.str(); 516 } 517 518 HloComputationProto HloComputation::ToProto() const { 519 HloComputationProto proto; 520 CHECK(unique_id_ != -1) 521 << "This computation does not have a valid id. Please make sure the " 522 "computation is inside a module before dumping it."; 523 proto.set_id(unique_id_); 524 proto.set_name(name_); 525 for (const HloInstruction* instruction : MakeInstructionPostOrder()) { 526 HloInstructionProto instruction_proto = instruction->ToProto(); 527 proto.add_instructions()->Swap(&instruction_proto); 528 } 529 proto.set_root_id(root_instruction()->unique_id()); 530 *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); 531 return proto; 532 } 533 534 /* static */ StatusOr<std::unique_ptr<HloComputation>> 535 HloComputation::CreateFromProto( 536 const HloComputationProto& proto, 537 const absl::flat_hash_map<int64, HloComputation*>& computation_map) { 538 absl::flat_hash_map<int64, HloInstruction*> instruction_map; 539 absl::flat_hash_map<HloInstruction*, int64> to_proto_id; 540 std::vector<std::unique_ptr<HloInstruction>> instructions; 541 int64 parameter_count = 0; 542 for (const HloInstructionProto& instruction_proto : proto.instructions()) { 543 TF_ASSIGN_OR_RETURN( 544 std::unique_ptr<HloInstruction> instruction, 545 HloInstruction::CreateFromProto(instruction_proto, instruction_map, 546 computation_map)); 547 if (instruction->opcode() == HloOpcode::kParameter) { 548 parameter_count++; 549 } 550 TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); 551 instruction_map[instruction_proto.id()] = instruction.get(); 552 to_proto_id[instruction.get()] = instruction_proto.id(); 553 instructions.push_back(std::move(instruction)); 554 } 555 556 TF_RET_CHECK(proto.root_id() != -1); 557 TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); 558 HloInstruction* root = instruction_map.at(proto.root_id()); 559 560 // Sort the instructions in the proto id's order. 561 absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a, 562 const std::unique_ptr<HloInstruction>& b) { 563 return to_proto_id[a.get()] < to_proto_id[b.get()]; 564 }); 565 566 TF_RETURN_IF_ERROR([&]() -> Status { 567 std::vector<bool> parameters_seen(parameter_count); 568 int parameters_seen_count = 0; 569 for (auto& instruction : instructions) { 570 if (instruction->opcode() == HloOpcode::kParameter) { 571 int64 param_no = instruction->parameter_number(); 572 TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) 573 << "Invalid parameter number. Expected [0, " << parameter_count 574 << "), got " << param_no; 575 TF_RET_CHECK(!parameters_seen[param_no]) 576 << "Parameter number " << param_no 577 << " already allocated in this computation"; 578 parameters_seen[param_no] = true; 579 parameters_seen_count++; 580 } 581 } 582 TF_RET_CHECK(parameters_seen_count == parameter_count) 583 << "Not all parameters in range [0, " << parameter_count 584 << ") were referenced"; 585 return Status::OK(); 586 }()); 587 588 auto computation = absl::WrapUnique( 589 new HloComputation(proto.name(), parameter_count, &instructions, root, 590 /*fusion_instruction=*/nullptr)); 591 computation->unique_id_ = proto.id(); 592 return std::move(computation); 593 } 594 595 void HloComputation::FuseInstructionsInto( 596 absl::Span<HloInstruction* const> instructions_to_fuse, 597 HloInstruction* fusion_instruction) { 598 CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); 599 HloInstruction* root = instructions_to_fuse.front(); 600 TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction)); 601 if (root == root_instruction()) { 602 set_root_instruction(fusion_instruction); 603 } 604 TF_CHECK_OK(RemoveInstruction(root)); 605 for (size_t i = 1; i < instructions_to_fuse.size(); ++i) { 606 HloInstruction* instruction = instructions_to_fuse[i]; 607 fusion_instruction->FuseInstruction(instruction); 608 if (instruction->user_count() == 0) { 609 TF_CHECK_OK(RemoveInstruction(instruction)); 610 } 611 } 612 } 613 614 HloInstruction* HloComputation::CreateFusionInstruction( 615 absl::Span<HloInstruction* const> instructions_to_fuse, 616 HloInstruction::FusionKind fusion_kind) { 617 HloInstruction* root = instructions_to_fuse.front(); 618 HloInstruction* fusion_instruction = AddInstruction( 619 HloInstruction::CreateFusion(root->shape(), fusion_kind, root)); 620 FuseInstructionsInto(instructions_to_fuse, fusion_instruction); 621 return fusion_instruction; 622 } 623 624 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper( 625 HloInstruction* instruction, ShapeIndex* index, 626 const std::function< 627 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, 628 HloComputation* computation)>& copy_leaf) { 629 if (instruction->shape().IsTuple()) { 630 std::vector<HloInstruction*> elements; 631 for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); 632 i++) { 633 HloInstruction* gte = 634 AddInstruction(HloInstruction::CreateGetTupleElement( 635 ShapeUtil::GetTupleElementShape(instruction->shape(), i), 636 instruction, i)); 637 638 index->push_back(i); 639 TF_ASSIGN_OR_RETURN(HloInstruction * element, 640 DeepCopyHelper(gte, index, copy_leaf)); 641 elements.push_back(element); 642 index->pop_back(); 643 } 644 return AddInstruction(HloInstruction::CreateTuple(elements)); 645 } 646 if (instruction->shape().IsToken()) { 647 // Tokens have no on-device representation and cannot be copied. Pass 648 // through transparently. 649 return instruction; 650 } 651 652 // Array shape. 653 TF_RET_CHECK(instruction->shape().IsArray()); 654 return copy_leaf(instruction, *index, this); 655 } 656 657 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( 658 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, 659 ShapeTree<HloInstruction*>* copies_added) { 660 if (instruction->parent() != this) { 661 return FailedPrecondition( 662 "Can't deep copy instruction %s: instruction is not in computation %s", 663 instruction->name(), name()); 664 } 665 if (indices_to_copy != nullptr && 666 !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { 667 return FailedPrecondition( 668 "Can't deep copy instruction %s: given shape tree of indices to copy " 669 "has incompatible shapes: %s vs. %s", 670 instruction->name(), ShapeUtil::HumanString(instruction->shape()), 671 ShapeUtil::HumanString(indices_to_copy->shape())); 672 } 673 674 ShapeIndex index; 675 auto copy_leaf = [indices_to_copy, copies_added]( 676 HloInstruction* leaf, const ShapeIndex& leaf_index, 677 HloComputation* computation) { 678 if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) { 679 HloInstruction* copy = computation->AddInstruction( 680 HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf)); 681 if (copies_added != nullptr) { 682 *copies_added->mutable_element(leaf_index) = copy; 683 } 684 return copy; 685 } 686 // Elements which are not to be copied are passed through 687 // transparently. 688 return leaf; 689 }; 690 return DeepCopyHelper(instruction, &index, copy_leaf); 691 } 692 693 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier( 694 HloInstruction* instruction, 695 const std::function< 696 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, 697 HloComputation* computation)>& copy_leaf) { 698 if (instruction->parent() != this) { 699 return FailedPrecondition( 700 "Can't deep copy instruction %s: instruction is not in computation %s", 701 instruction->name(), name()); 702 } 703 ShapeIndex index; 704 return DeepCopyHelper(instruction, &index, copy_leaf); 705 } 706 707 ProgramShape HloComputation::ComputeProgramShape() const { 708 ProgramShape program_shape; 709 710 for (auto* param_instruction : param_instructions_) { 711 *program_shape.add_parameters() = param_instruction->shape(); 712 *program_shape.add_parameter_names() = param_instruction->name(); 713 } 714 *program_shape.mutable_result() = root_instruction_->shape(); 715 716 return program_shape; 717 } 718 719 bool HloComputation::operator==(const HloComputation& other) const { 720 if (this == &other) { 721 return true; 722 } 723 absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>> 724 visited; 725 std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist; 726 727 worklist.push_back({root_instruction(), other.root_instruction()}); 728 729 while (!worklist.empty()) { 730 auto pair = worklist.back(); 731 worklist.pop_back(); 732 733 if (visited.contains(pair)) { 734 continue; 735 } 736 visited.emplace(pair); 737 // TODO(b/123082518): Avoid recursively invoking == becasue it may 738 // cause a stack overflow with deeply nested subcomputations. 739 bool identical_ignoring_operands = pair.first->Identical( 740 *pair.second, 741 [](const HloInstruction*, const HloInstruction*) { return true; }, 742 [](const HloComputation* a, const HloComputation* b) { 743 return *a == *b; 744 }); 745 if (!identical_ignoring_operands) { 746 return false; 747 } 748 for (size_t i = 0; i < pair.first->operands().size(); ++i) { 749 worklist.push_back({pair.first->operand(i), pair.second->operand(i)}); 750 } 751 } 752 return true; 753 } 754 755 Status HloComputation::ReplaceWithNewInstruction( 756 HloInstruction* old_instruction, 757 std::unique_ptr<HloInstruction> new_instruction) { 758 return ReplaceInstruction(old_instruction, 759 AddInstruction(std::move(new_instruction))); 760 } 761 762 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, 763 HloInstruction* new_instruction) { 764 TF_RET_CHECK( 765 ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) 766 << ShapeUtil::HumanString(old_instruction->shape()) << " vs " 767 << ShapeUtil::HumanString(new_instruction->shape()); 768 769 VLOG(10) << "transformed " << old_instruction->ToString() << " to " 770 << new_instruction->ToString(); 771 // Try to add metadata for HLO instructions that are created to replace 772 // existing HLO instructions (e.g. during optimizations). The assumption is 773 // that the old instruction and the new instruction would perform the same 774 // function, and that they would be correlated to the same TF op. This might 775 // not always be correct since HLO optimizations can cross TF op boundaries. 776 // But still this seems to be better than nothing. 777 if (new_instruction->metadata().op_name().empty()) { 778 new_instruction->set_metadata(old_instruction->metadata()); 779 } 780 TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction)); 781 return RemoveInstructionAndUnusedOperands(old_instruction); 782 } 783 784 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { 785 std::vector<HloInstruction*> unreachable_roots; 786 for (auto* instruction : instructions()) { 787 if (instruction->user_count() == 0 && 788 instruction->control_successors().empty() && 789 instruction != root_instruction()) { 790 unreachable_roots.push_back(instruction); 791 } 792 } 793 VLOG(3) << "Unreachable roots:" 794 << absl::StrJoin(unreachable_roots, "\n\t", 795 [](string* out, const HloInstruction* hlo) { 796 absl::StrAppend(out, hlo->ToString()); 797 }); 798 return unreachable_roots; 799 } 800 801 template <typename HloInstructionPtr> 802 Status HloComputation::Accept( 803 DfsHloVisitorBase<HloInstructionPtr>* visitor) const { 804 // Visit unreachable roots. Beware that the visitor might delete the currently 805 // visited root, which would invalidate iterators if the unreachable roots 806 // weren't computed ahead of time. 807 for (HloInstruction* root : CollectUnreachableRoots()) { 808 VLOG(3) << "Traversing unreachable root: " << root->ToString(); 809 // Call FinishVisit only at the end. 810 TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); 811 } 812 // Visit the computation root instruction last. 813 return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); 814 } 815 816 // Explicit instantiations. 817 template Status HloComputation::Accept(DfsHloVisitor* visitor) const; 818 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; 819 820 Status HloComputation::AcceptWithOperandOrder( 821 DfsHloVisitor* visitor, 822 const HloInstruction::CompareFunction& operand_order) const { 823 // Visit unreachable roots. Beware that the visitor might delete the currently 824 // visited root, which would invalidate iterators if the unreachable roots 825 // weren't computed ahead of time. 826 for (HloInstruction* root : CollectUnreachableRoots()) { 827 TF_RETURN_IF_ERROR( 828 root->AcceptWithOperandOrder(visitor, operand_order, 829 /*call_finish_visit=*/false)); 830 } 831 // Visit the computation root instruction last. 832 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order, 833 /*call_finish_visit=*/true); 834 } 835 836 template <typename HloInstructionPtr> 837 Status HloComputation::AcceptOrdered( 838 DfsHloVisitorBase<HloInstructionPtr>* visitor, 839 absl::Span<HloInstruction* const> order) const { 840 VLOG(3) << "Accepting visitor with order."; 841 for (HloInstruction* root : CollectUnreachableRoots()) { 842 TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); 843 } 844 TF_RET_CHECK(order.size() == instruction_count()); 845 absl::flat_hash_set<const HloInstruction*> visited; 846 for (const HloInstruction* instruction : order) { 847 VLOG(3) << "Visiting ordered: " << instruction->ToString(); 848 TF_RET_CHECK(instruction_iterators_.contains(instruction)) 849 << "Instruction " << instruction->name() << " is not in computation " 850 << name(); 851 TF_RET_CHECK(!visited.contains(instruction)) 852 << "Instruction " << instruction->name() 853 << " appears more than once in order"; 854 HloInstruction* mutable_instruction = 855 const_cast<HloInstruction*>(instruction); 856 TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction)); 857 TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor)); 858 visitor->SetVisited(*mutable_instruction); 859 TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction)); 860 visited.insert(instruction); 861 } 862 TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction())); 863 return Status::OK(); 864 } 865 866 // Explicit instantiations. 867 template Status HloComputation::AcceptOrdered( 868 DfsHloVisitor*, absl::Span<HloInstruction* const>) const; 869 template Status HloComputation::AcceptOrdered( 870 ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const; 871 872 Status HloComputation::Accept( 873 const std::function<Status(HloInstruction*)>& visitor_func) { 874 FunctionVisitor visitor(visitor_func); 875 return this->Accept(&visitor); 876 } 877 878 Status HloComputation::Accept( 879 const std::function<Status(const HloInstruction*)>& visitor_func) const { 880 ConstFunctionVisitor visitor(visitor_func); 881 return this->Accept(&visitor); 882 } 883 884 std::unique_ptr<HloComputation> HloComputation::Clone( 885 const string& suffix, HloCloneContext* context) { 886 return CloneWithReplacements( 887 /*replacements=*/absl::flat_hash_map<const HloInstruction*, 888 std::unique_ptr<HloInstruction>>(), 889 /*extra_parameters=*/{}, context, suffix); 890 } 891 892 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs( 893 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1, 894 HloCloneContext* context, const string& suffix) { 895 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 896 replacements; 897 replacements.emplace(std::move(r1)); 898 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, 899 context, suffix); 900 } 901 902 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs( 903 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1, 904 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2, 905 HloCloneContext* context, const string& suffix) { 906 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 907 replacements; 908 replacements.emplace(std::move(r1)); 909 replacements.emplace(std::move(r2)); 910 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, 911 context, suffix); 912 } 913 914 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs( 915 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1, 916 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2, 917 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3, 918 HloCloneContext* context, const string& suffix) { 919 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 920 replacements; 921 replacements.emplace(std::move(r1)); 922 replacements.emplace(std::move(r2)); 923 replacements.emplace(std::move(r3)); 924 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, 925 context, suffix); 926 } 927 928 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( 929 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 930 replacements, 931 absl::Span<const HloInstruction* const> extra_parameters, 932 HloCloneContext* context, const string& suffix) { 933 std::unique_ptr<HloCloneContext> context_ptr; 934 if (context == nullptr) { 935 context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix); 936 context = context_ptr.get(); 937 } 938 939 // Look up instr in the replacements map, and return either the replacement, 940 // or instr, if the replacement isn't present. 941 // 942 // Note: This can return null, indicating that instr should not be present in 943 // the new computation. 944 auto replace = [&](HloInstruction* instr) { 945 auto it = replacements.find(instr); 946 if (it == replacements.end()) { 947 return instr; 948 } 949 return it->second.get(); 950 }; 951 952 VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; 953 954 // We want to do a postorder walk over [replace(i) for i in instructions_]. 955 // We can't reuse MakeInstructionPostOrder() for this, because that will 956 // generate a postorder of plain instructions_, and our replacements may 957 // change the postorder! 958 // 959 // The postorder we want here is simpler than what MakeInstructionPostOrder() 960 // does -- we only care about operand dependencies -- so let's just do it 961 // ourselves. 962 std::vector<HloInstruction*> postorder; 963 absl::flat_hash_map<HloInstruction*, VisitState> visited; 964 for (const auto& instr : instructions_) { 965 std::vector<HloInstruction*> dfs_stack; 966 HloInstruction* new_instr = replace(instr.get()); 967 if (!new_instr) { 968 continue; 969 } 970 dfs_stack.push_back(new_instr); 971 972 while (!dfs_stack.empty()) { 973 auto* cur = dfs_stack.back(); 974 auto it = visited.find(cur); 975 if (it != visited.end()) { 976 dfs_stack.pop_back(); 977 if (it->second == kVisited) { 978 continue; 979 } 980 CHECK_EQ(it->second, kVisiting); 981 postorder.push_back(cur); 982 it->second = kVisited; 983 continue; 984 } 985 986 visited.insert({cur, kVisiting}); 987 for (HloInstruction* operand : cur->operands()) { 988 HloInstruction* new_operand = replace(operand); 989 if (new_operand) { 990 dfs_stack.emplace_back(new_operand); 991 } 992 } 993 } 994 } 995 996 std::vector<std::unique_ptr<HloInstruction>> instructions; 997 // First add the extra parameters to 'instructions'. 998 for (const auto& instr : extra_parameters) { 999 CHECK_EQ(instr->opcode(), HloOpcode::kParameter) 1000 << "Only parameter instructions are allowed in 'extra_parameters'"; 1001 instructions.emplace_back(instr->Clone()); 1002 } 1003 for (auto instr : postorder) { 1004 std::vector<HloInstruction*> new_operands; 1005 for (auto operand : instr->operands()) { 1006 auto replaced_operand = replace(operand); 1007 CHECK_NE(replaced_operand, nullptr) 1008 << "replacements map tried to eliminate a used instruction " 1009 << operand->ToString() << ", used by " << instr->ToString(); 1010 new_operands.push_back(context->GetInstruction(replaced_operand)); 1011 } 1012 instructions.push_back( 1013 instr->CloneWithNewOperands(instr->shape(), new_operands, context)); 1014 } 1015 Builder builder(name() + "." + suffix); 1016 for (auto& instr : instructions) { 1017 builder.AddInstruction(std::move(instr)); 1018 } 1019 auto result = builder.Build( 1020 /*root_instruction=*/context->GetInstruction( 1021 replace(root_instruction()))); 1022 1023 // Clone control dependencies. 1024 for (auto instr : postorder) { 1025 HloInstruction* new_instr = context->GetInstruction(instr); 1026 for (auto successor : instr->control_successors()) { 1027 auto replaced_successor = replace(successor); 1028 // successor may not have been remapped, because it might have been 1029 // removed by the replacements map. 1030 if (replaced_successor != nullptr) { 1031 TF_CHECK_OK(new_instr->AddControlDependencyTo( 1032 context->GetInstruction(replaced_successor))); 1033 } 1034 } 1035 } 1036 context->MapComputation(this, result.get()); 1037 return result; 1038 } 1039 1040 void HloComputation::UniquifyName(NameUniquer* name_uniquer) { 1041 name_ = name_uniquer->GetUniqueName(name_); 1042 } 1043 1044 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { 1045 auto instructions_in_computation = instructions(); 1046 auto it = absl::c_find_if( 1047 instructions_in_computation, 1048 [&](HloInstruction* instr) { return instr->name() == name; }); 1049 return it == instructions_in_computation.end() ? nullptr : *it; 1050 } 1051 1052 } // namespace xla 1053