1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 17 18 #include <algorithm> 19 #include <queue> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/map_util.h" 23 #include "tensorflow/compiler/xla/ptr_util.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/status.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/util.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/strings/str_util.h" 33 #include "tensorflow/core/lib/strings/strcat.h" 34 #include "tensorflow/core/platform/logging.h" 35 36 namespace xla { 37 38 using ::tensorflow::strings::StrAppend; 39 using ::tensorflow::strings::StrCat; 40 41 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, 42 bool bitcast_defines_value) 43 : module_(module), 44 ssa_form_(ssa_form), 45 bitcast_defines_value_(bitcast_defines_value), 46 call_graph_(CallGraph::Build(&module)) {} 47 48 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, 49 const ShapeIndex& index) const { 50 const HloValueSet& value_set = GetValueSet(instruction, index); 51 if (value_set.values().size() != 1) { 52 return false; 53 } 54 return value_set.GetUniqueValue().defining_instruction() == instruction; 55 } 56 57 const HloValue& HloDataflowAnalysis::GetValueDefinedAt( 58 const HloInstruction* instruction, const ShapeIndex& index) const { 59 CHECK(ValueIsDefinedAt(instruction, index)); 60 return GetUniqueValueAt(instruction, index); 61 } 62 63 HloValue& HloDataflowAnalysis::GetValueDefinedAt( 64 const HloInstruction* instruction, const ShapeIndex& index) { 65 CHECK(ValueIsDefinedAt(instruction, index)); 66 return GetUniqueValueAt(instruction, index); 67 } 68 69 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, 70 const ShapeIndex& index, 71 bool is_phi) { 72 const int64 value_id = next_value_id_++; 73 auto emplaced = values_.emplace( 74 std::piecewise_construct, std::forward_as_tuple(value_id), 75 std::forward_as_tuple(value_id, instruction, index, is_phi)); 76 CHECK(emplaced.second); 77 78 VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); 79 80 return &emplaced.first->second; 81 } 82 83 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { 84 HloValue& value = values_.at(value_id); 85 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; 86 87 value_ids_to_delete_.push_back(value_id); 88 } 89 90 void HloDataflowAnalysis::DeleteMarkedValues() { 91 #ifndef NDEBUG 92 // Verify that no marked-for-deletion values are in any of the value sets. 93 tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(), 94 value_ids_to_delete_.end()); 95 for (const auto& pair : value_sets_) { 96 const HloInstruction* instruction = pair.first; 97 const InstructionValueSet& instruction_value_set = pair.second; 98 for (const auto& index_value_set : instruction_value_set) { 99 const HloValueSet& value_set = index_value_set.second; 100 for (const HloValue* value : value_set.values()) { 101 DCHECK(!ContainsKey(id_set, value->id())) 102 << "Value " << value->ToShortString() 103 << " marked for deletion, but still exists in value set for " 104 "instruction " 105 << instruction->name(); 106 } 107 } 108 } 109 #endif 110 111 for (HloValue::Id value_id : value_ids_to_delete_) { 112 values_.erase(value_id); 113 } 114 value_ids_to_delete_.clear(); 115 } 116 117 string HloDataflowAnalysis::ToString() const { 118 string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n"); 119 StrAppend(&out, " Instruction value sets:\n"); 120 for (const HloComputation* computation : module_.computations()) { 121 for (const HloInstruction* instruction : computation->instructions()) { 122 StrAppend(&out, " ", instruction->name(), ":\n"); 123 if (ShapeUtil::IsTuple(instruction->shape())) { 124 GetInstructionValueSet(instruction) 125 .ForEachElement([this, &instruction, &out]( 126 const ShapeIndex& index, 127 const HloValueSet& value_set) { 128 StrAppend(&out, " tuple index ", index.ToString(), ":\n"); 129 for (const HloValue* value : value_set.values()) { 130 StrAppend(&out, " ", value->ToShortString(), 131 ValueIsDefinedAt(instruction, index) ? " (def)" : "", 132 "\n"); 133 } 134 }); 135 } else { 136 const HloValueSet& top_level_value_set = 137 GetValueSet(instruction, /*index=*/{}); 138 for (const HloValue* value : top_level_value_set.values()) { 139 StrAppend(&out, " ", value->ToShortString(), 140 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n"); 141 } 142 } 143 } 144 } 145 StrAppend(&out, " HloValues:\n"); 146 for (const HloValue* value : values()) { 147 StrAppend(&out, value->ToString(/*indent=*/4)); 148 } 149 return out; 150 } 151 152 bool HloDataflowAnalysis::Phi( 153 HloInstruction* instruction, 154 tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) { 155 CHECK(ssa_form_); 156 VLOG(4) << "Phi(" << instruction->name() << ")"; 157 VLOG(5) << "instruction value set = " 158 << GetInstructionValueSet(instruction).ToString(); 159 for (const InstructionValueSet* input : inputs) { 160 VLOG(5) << "input value set = " << input->ToString(); 161 } 162 for (const InstructionValueSet* input : inputs) { 163 DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); 164 } 165 166 bool changed = false; 167 for (auto& pair : GetInstructionValueSet(instruction)) { 168 const ShapeIndex& index = pair.first; 169 HloValueSet& value_set = pair.second; 170 171 // Positions with phi values should never have more than one value in the 172 // value set. 173 CHECK_LE(value_set.values().size(), 1); 174 const HloValue* current_value = 175 value_set.values().size() == 1 ? value_set.values()[0] : nullptr; 176 177 // Construct a vector of unique value IDs of the inputs. 178 // Don't add value ids where the input is equal to the definition. 179 std::vector<HloValue::Id> input_value_ids; 180 for (const InstructionValueSet* input : inputs) { 181 for (const HloValue* value : input->element(index).values()) { 182 if (value->defining_instruction() == instruction && 183 value->defining_index() == index) { 184 continue; 185 } 186 input_value_ids.push_back(value->id()); 187 } 188 } 189 std::sort(input_value_ids.begin(), input_value_ids.end()); 190 input_value_ids.erase( 191 std::unique(input_value_ids.begin(), input_value_ids.end()), 192 input_value_ids.end()); 193 194 // Remove the existing phi value (if it exists). The phi can be its own 195 // input, for example, in while body parameters where the body passes 196 // through the parameter value. 197 bool current_value_defined_here = 198 (current_value != nullptr && 199 current_value->defining_instruction() == instruction && 200 current_value->defining_index() == index); 201 if (current_value_defined_here) { 202 VLOG(5) << "current_value_defined_here: " << current_value->ToString(); 203 CHECK(current_value->is_phi()); 204 auto it = std::find(input_value_ids.begin(), input_value_ids.end(), 205 current_value->id()); 206 if (it != input_value_ids.end()) { 207 input_value_ids.erase(it); 208 } 209 } 210 VLOG(5) << "after input_value_ids.size = " << input_value_ids.size(); 211 if (input_value_ids.empty()) { 212 // A value set which has at least one element should never have its value 213 // set reduced to zero elements. During dataflow value sets only can go 214 // from empty to non-empty, not the reverse. 215 CHECK_EQ(value_set.values().size(), 0) 216 << "Instruction " << instruction->name() << " at index " << index 217 << " previously had non-empty value set. Value set: " << value_set; 218 } else if (input_value_ids.size() == 1) { 219 // Only a single value reaches this point. There should be no phi, and 220 // this value set should contain this single value. 221 const HloValue& new_value = GetValue(input_value_ids[0]); 222 if (current_value == nullptr) { 223 value_set.Clear(); 224 value_set.AddValue(&new_value); 225 changed = true; 226 } else if (current_value != &new_value) { 227 if (current_value_defined_here) { 228 // Remove the existing phi. 229 MarkValueForDeletion(current_value->id()); 230 } 231 value_set.Clear(); 232 value_set.AddValue(&new_value); 233 changed = true; 234 } 235 } else { 236 // Multiple distinct values reach this point. A phi value is 237 // necessary. 238 CHECK_GT(input_value_ids.size(), 1); 239 if (current_value == nullptr || 240 !(current_value->is_phi() && current_value_defined_here)) { 241 value_set.Clear(); 242 value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); 243 changed = true; 244 } 245 } 246 } 247 return changed; 248 } 249 250 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const { 251 return values_.at(value_id); 252 } 253 254 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) { 255 return values_.at(value_id); 256 } 257 258 const HloValueSet& HloDataflowAnalysis::GetValueSet( 259 const HloInstruction* instruction, const ShapeIndex& index) const { 260 return GetInstructionValueSet(instruction).element(index); 261 } 262 263 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction, 264 const ShapeIndex& index) { 265 return *GetInstructionValueSet(instruction).mutable_element(index); 266 } 267 268 const HloValueSet& HloDataflowAnalysis::GetValueSet( 269 const HloPosition& position) const { 270 return GetValueSet(position.instruction, position.index); 271 } 272 273 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) { 274 return GetValueSet(position.instruction, position.index); 275 } 276 277 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { 278 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); 279 const InstructionValueSet& operand_set = 280 GetInstructionValueSet(bitcast->operand(0)); 281 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast); 282 if (!bitcast_defines_value_ && operand_set != bitcast_set) { 283 bitcast_set = operand_set; 284 return true; 285 } 286 return false; 287 } 288 289 bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { 290 CHECK_EQ(slice->opcode(), HloOpcode::kSlice); 291 if (!slice->IsInPlaceSlice()) { 292 return false; 293 } 294 // If this slice is lowered to an in-place version, then it forwards the 295 // operand value to the output. 296 const InstructionValueSet& operand_set = 297 GetInstructionValueSet(slice->operand(0)); 298 InstructionValueSet& slice_set = GetInstructionValueSet(slice); 299 if (operand_set != slice_set) { 300 slice_set = operand_set; 301 return true; 302 } 303 return false; 304 } 305 306 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { 307 CHECK_EQ(send->opcode(), HloOpcode::kSend); 308 bool changed = false; 309 // Send forwards the operand value to the output tuple at {0}. 310 for (auto& pair : GetInstructionValueSet(send->operand(0))) { 311 const ShapeIndex& operand_index = pair.first; 312 const HloValueSet& operand_value_set = pair.second; 313 314 ShapeIndex index = {0}; 315 for (int64 i : operand_index) { 316 index.push_back(i); 317 } 318 319 HloValueSet& value_set = GetValueSet(send, index); 320 if (value_set != operand_value_set) { 321 value_set = operand_value_set; 322 changed = true; 323 } 324 } 325 return changed; 326 } 327 328 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { 329 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); 330 bool changed = false; 331 // RecvDone forwards the operand value at {0} to the output. 332 for (auto& pair : GetInstructionValueSet(recv_done)) { 333 ShapeIndex& index = pair.first; 334 HloValueSet& value_set = pair.second; 335 336 ShapeIndex operand_index = {0}; 337 for (int64 i : index) { 338 operand_index.push_back(i); 339 } 340 341 const HloValueSet& operand_value_set = 342 GetValueSet(recv_done->operand(0), operand_index); 343 if (value_set != operand_value_set) { 344 value_set = operand_value_set; 345 changed = true; 346 } 347 } 348 return changed; 349 } 350 351 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { 352 CHECK_EQ(call->opcode(), HloOpcode::kCall); 353 InstructionValueSet& value_set = GetInstructionValueSet(call); 354 InstructionValueSet& root_value_set = 355 GetInstructionValueSet(call->to_apply()->root_instruction()); 356 if (value_set != root_value_set) { 357 value_set = root_value_set; 358 return true; 359 } 360 return false; 361 } 362 363 bool HloDataflowAnalysis::UpdateConditionalValueSet( 364 HloInstruction* conditional) { 365 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); 366 std::vector<const InstructionValueSet*> inputs = { 367 &GetInstructionValueSet( 368 conditional->true_computation()->root_instruction()), 369 &GetInstructionValueSet( 370 conditional->false_computation()->root_instruction())}; 371 // A phi-node is not defined for a kConditional instruction even though it 372 // represents a join point. This is because the current approach is to define 373 // a phi-node only for kWhile to account for the dataflow through back-edges 374 // and deal with the ambiguity in other cases. 375 return GetInstructionValueSet(conditional).AssignUnionOf(inputs); 376 } 377 378 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { 379 CHECK_EQ(copy->opcode(), HloOpcode::kCopy); 380 bool changed = false; 381 for (auto& pair : GetInstructionValueSet(copy)) { 382 const ShapeIndex& index = pair.first; 383 if (index.empty()) { 384 // kCopy shallow copies and thus defines the top-level value so nothing to 385 // update. 386 continue; 387 } 388 389 HloValueSet& value_set = pair.second; 390 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index); 391 if (value_set != operand_value_set) { 392 value_set = operand_value_set; 393 changed = true; 394 } 395 } 396 return changed; 397 } 398 399 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { 400 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); 401 bool changed = false; 402 // The GetTupleElement instruction forwards the values from the specified 403 // tuple element. 404 for (auto& pair : GetInstructionValueSet(gte)) { 405 const ShapeIndex& index = pair.first; 406 HloValueSet& value_set = pair.second; 407 408 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex 409 // with the tuple element number prefixed. 410 ShapeIndex operand_index = {gte->tuple_index()}; 411 for (int64 i : index) { 412 operand_index.push_back(i); 413 } 414 415 HloValueSet& operand_value_set = 416 GetValueSet(gte->operand(0), operand_index); 417 if (value_set != operand_value_set) { 418 value_set = operand_value_set; 419 changed = true; 420 } 421 } 422 return changed; 423 } 424 425 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { 426 CHECK_EQ(parameter->opcode(), HloOpcode::kParameter); 427 const CallGraphNode& call_graph_node = 428 call_graph_->GetNode(parameter->parent()); 429 430 // Subcomputations called in a parallel context (eg, map) do not have dataflow 431 // from the caller operands. 432 if (call_graph_node.context() == CallContext::kParallel || 433 call_graph_node.caller_callsites().empty()) { 434 return false; 435 } 436 CHECK_EQ(call_graph_node.context(), CallContext::kSequential); 437 438 std::vector<const InstructionValueSet*> inputs; 439 bool need_phi = false; 440 for (const CallSite& callsite : call_graph_node.caller_callsites()) { 441 if (callsite.instruction()->opcode() == HloOpcode::kCall) { 442 // The operand values of a call instruction are forwarded to the 443 // respective parameter instruction of the subcomputation. 444 inputs.push_back(&GetInstructionValueSet( 445 callsite.instruction()->operand(parameter->parameter_number()))); 446 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { 447 // In a while instruction, the while operand (ie, the init value) and the 448 // backedge are dataflow inputs to the parameter instruction. This is the 449 // case for parameters of both the body and condition computations. 450 CHECK_EQ(parameter->parameter_number(), 0); 451 inputs.push_back( 452 &GetInstructionValueSet(callsite.instruction()->operand(0))); 453 // If the parameter *is* the root, then don't consider it's current state 454 // (InstructionValueSet) as we are recomputing its current 455 // state. Otherwise, the parameter state would never be updated. 456 if (parameter != 457 callsite.instruction()->while_body()->root_instruction()) { 458 inputs.push_back(&GetInstructionValueSet( 459 callsite.instruction()->while_body()->root_instruction())); 460 } 461 need_phi = true; 462 } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { 463 CHECK_EQ(parameter->parameter_number(), 0); 464 auto conditional = callsite.instruction(); 465 // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is 466 // the argument to the true computation and operand 2 is the argument to 467 // the false computation. 468 // 469 // If the parameter belongs to conditional's true computation, then 470 // operand 1 is forwarded to this parameter instruction. If the parameter 471 // belongs to conditional's false computation, then operand 2 is forwarded 472 // to this parameter instruction. 473 if (parameter->parent() == conditional->true_computation()) { 474 inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); 475 } else { 476 CHECK_EQ(parameter->parent(), conditional->false_computation()); 477 inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); 478 } 479 need_phi = true; 480 } else { 481 LOG(FATAL) << "CallContext::kSequential computations should only be " 482 "called from call, while, or conditional instructions"; 483 } 484 } 485 486 if (ssa_form_ && need_phi) { 487 return Phi(parameter, inputs); 488 } else { 489 return GetInstructionValueSet(parameter).AssignUnionOf(inputs); 490 } 491 } 492 493 bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) { 494 CHECK_EQ(select->opcode(), HloOpcode::kSelect); 495 // A phi value is not defined at a kSelect instruction because kSelect does 496 // not create a new value. Rather it forwards a value from its operands. This 497 // contrasts with kWhile instruction (which does define a phi value) which has 498 // in-place update semantics. 499 bool changed = false; 500 for (auto& pair : GetInstructionValueSet(select)) { 501 const ShapeIndex& index = pair.first; 502 if (index.empty()) { 503 // kSelect copies (not forwards) the top-level value. 504 continue; 505 } 506 HloValueSet& value_set = pair.second; 507 changed |= 508 value_set.AssignUnionOf({&GetValueSet(select->operand(1), index), 509 &GetValueSet(select->operand(2), index)}); 510 } 511 return changed; 512 } 513 514 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { 515 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); 516 bool changed = false; 517 for (int64 i = 0; i < tuple->operands().size(); ++i) { 518 // Copy the value set(s) of each operand into the respective position in the 519 // kTuple instruction's value sets. 520 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) { 521 const ShapeIndex& operand_index = pair.first; 522 HloValueSet& operand_value_set = pair.second; 523 524 ShapeIndex index = {i}; 525 for (int64 op_index : operand_index) { 526 index.push_back(op_index); 527 } 528 HloValueSet& value_set = GetValueSet(tuple, index); 529 530 if (value_set != operand_value_set) { 531 value_set = operand_value_set; 532 changed = true; 533 } 534 } 535 } 536 return changed; 537 } 538 539 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { 540 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); 541 std::vector<const InstructionValueSet*> inputs = { 542 &GetInstructionValueSet(xla_while->while_body()->root_instruction()), 543 &GetInstructionValueSet(xla_while->operand(0))}; 544 if (ssa_form_) { 545 return Phi(xla_while, inputs); 546 } else { 547 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs); 548 } 549 } 550 551 bool HloDataflowAnalysis::UpdateInstructionValueSet( 552 HloInstruction* instruction) { 553 // Recompute from operands. 554 switch (instruction->opcode()) { 555 case HloOpcode::kBitcast: 556 return UpdateBitcastValueSet(instruction); 557 case HloOpcode::kSlice: 558 return UpdateSliceValueSet(instruction); 559 case HloOpcode::kCopy: 560 return UpdateCopyValueSet(instruction); 561 case HloOpcode::kGetTupleElement: 562 return UpdateGetTupleElementValueSet(instruction); 563 case HloOpcode::kSelect: 564 return UpdateSelectValueSet(instruction); 565 case HloOpcode::kTuple: 566 return UpdateTupleValueSet(instruction); 567 case HloOpcode::kParameter: 568 return UpdateParameterValueSet(instruction); 569 case HloOpcode::kCall: 570 return UpdateCallValueSet(instruction); 571 case HloOpcode::kWhile: 572 return UpdateWhileValueSet(instruction); 573 case HloOpcode::kSend: 574 return UpdateSendValueSet(instruction); 575 case HloOpcode::kRecvDone: 576 return UpdateRecvDoneValueSet(instruction); 577 case HloOpcode::kConditional: 578 return UpdateConditionalValueSet(instruction); 579 default: 580 // Instruction does not forward HloValues (it defines all values in its 581 // output). No update is necessary. 582 return false; 583 } 584 } 585 586 void HloDataflowAnalysis::Propagate() { 587 std::queue<HloInstruction*> worklist; 588 tensorflow::gtl::FlatSet<HloInstruction*> workset; 589 auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { 590 if (workset.insert(instruction).second) { 591 worklist.push(instruction); 592 } 593 }; 594 595 for (HloComputation* computation : module_.computations()) { 596 for (HloInstruction* instruction : computation->instructions()) { 597 add_to_worklist(instruction); 598 } 599 } 600 601 while (!worklist.empty()) { 602 HloInstruction* instruction = worklist.front(); 603 worklist.pop(); 604 workset.erase(workset.find(instruction)); 605 606 VLOG(3) << "Worklist top: " << instruction->name(); 607 VLOG(3) << ToString(); 608 609 if (!UpdateInstructionValueSet(instruction)) { 610 // No change to the instruction's value set. 611 VLOG(4) << "No change."; 612 continue; 613 } 614 615 VLOG(4) << "New value set for " << instruction->name() << ": " 616 << GetInstructionValueSet(instruction); 617 618 // Instruction value was updated. Add users to work list if we haven't 619 // already. 620 for (HloInstruction* user : instruction->users()) { 621 add_to_worklist(user); 622 623 // If user sequentially calls a computation, then the respective 624 // parameter(s) of the computation need to be updated. 625 if (user->opcode() == HloOpcode::kConditional) { 626 // If operand 0 is the use of instruction, then no parameters need to be 627 // updated, since that is the predicate of the conditional. 628 // If operand 1 is the use of instruction, then the true_computation's 629 // parameter need to be updated. 630 // If operand 2 is the use of instruction, then the false_computation's 631 // parameter need to be updated. 632 // 633 // Note that the same instruction can be used in both operand 1 and 634 // operand 2. 635 if (user->operand(1) == instruction) { 636 add_to_worklist(user->true_computation()->parameter_instruction(0)); 637 } 638 if (user->operand(2) == instruction) { 639 add_to_worklist(user->false_computation()->parameter_instruction(0)); 640 } 641 } else { 642 for (HloComputation* called_computation : user->called_computations()) { 643 const CallGraphNode& call_graph_node = 644 call_graph_->GetNode(called_computation); 645 if (call_graph_node.context() == CallContext::kSequential) { 646 for (int64 operand_number : user->OperandIndices(instruction)) { 647 add_to_worklist( 648 called_computation->parameter_instruction(operand_number)); 649 } 650 } 651 } 652 } 653 } 654 655 // If instruction is a root instruction, then propagate out to any calling 656 // instruction and across any while backedge. 657 if (instruction == instruction->parent()->root_instruction()) { 658 const CallGraphNode& call_graph_node = 659 call_graph_->GetNode(instruction->parent()); 660 for (const CallSite& callsite : call_graph_node.caller_callsites()) { 661 if ((callsite.instruction()->opcode() == HloOpcode::kCall) || 662 (callsite.instruction()->opcode() == HloOpcode::kConditional)) { 663 add_to_worklist(callsite.instruction()); 664 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { 665 // Add the while itself, and the body and condition parameters. 666 add_to_worklist(callsite.instruction()); 667 add_to_worklist( 668 callsite.instruction()->while_body()->parameter_instruction(0)); 669 add_to_worklist( 670 callsite.instruction()->while_condition()->parameter_instruction( 671 0)); 672 } 673 } 674 } 675 } 676 } 677 678 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( 679 const HloInstruction* instruction) const { 680 return value_sets_.at(instruction); 681 } 682 683 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( 684 const HloInstruction* instruction) { 685 return value_sets_.at(instruction); 686 } 687 688 Status HloDataflowAnalysis::InitializeInstructionValueSets() { 689 for (const HloComputation* computation : module_.computations()) { 690 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); 691 for (HloInstruction* instruction : computation->instructions()) { 692 // Create an empty shape tree. 693 value_sets_.emplace(std::piecewise_construct, 694 std::forward_as_tuple(instruction), 695 std::forward_as_tuple(instruction->shape())); 696 697 // Lambda to set the value set to define all values in the output of the 698 // instruction. 699 auto define_all_values = [this, &instruction](bool is_phi = false) { 700 for (auto& pair : GetInstructionValueSet(instruction)) { 701 const ShapeIndex& index = pair.first; 702 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); 703 GetValueSet(instruction, index).AddValue(value); 704 } 705 }; 706 707 // Lambda to set the value set to define only the top-level buffer in the 708 // output of the instruction. Any other values flow from the operands of 709 // the instruction (or from cross-computation dataflow). 710 auto define_top_level_only = [this, &instruction]() { 711 HloValue* value = 712 NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false); 713 GetValueSet(instruction, /*index=*/{}).AddValue(value); 714 }; 715 716 // Lambda to set the value set at the given index of the output. 717 auto define_value_at = [this, &instruction](const ShapeIndex& index) { 718 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); 719 GetValueSet(instruction, index).AddValue(value); 720 }; 721 722 switch (instruction->opcode()) { 723 case HloOpcode::kBitcast: 724 if (bitcast_defines_value_) { 725 define_all_values(); 726 } 727 break; 728 case HloOpcode::kSlice: 729 if (!instruction->IsInPlaceSlice()) { 730 define_all_values(); 731 } 732 break; 733 case HloOpcode::kWhile: 734 case HloOpcode::kCall: 735 case HloOpcode::kConditional: 736 case HloOpcode::kGetTupleElement: 737 // These instructions define no values. The values in their output 738 // flow from their operands or from cross computation dataflow. 739 break; 740 case HloOpcode::kParameter: 741 if (call_graph_node.context() == CallContext::kBoth) { 742 // We do not support a subcomputation that is called from both a 743 // parallel and sequential context. In this case, the parameter 744 // would both define a value and propagate a value from its 745 // caller. This limitation is not really a problem because the call 746 // graph is typically flattened. 747 return Unimplemented( 748 "Computation %s is called in both a parallel (eg, kMap) and " 749 "sequential (eg, kCall) context", 750 computation->name().c_str()); 751 } 752 if (call_graph_node.caller_callsites().empty() || 753 call_graph_node.context() == CallContext::kParallel) { 754 // Parameters of computations called in a parallel context (eg, map 755 // and reduce) as well as parameters of dead computations define all 756 // values in their output. Otherwise the values of the parameter 757 // come from the caller (eg, operands to the kCall instruction). 758 define_all_values(); 759 } 760 break; 761 case HloOpcode::kCopy: 762 case HloOpcode::kSelect: 763 case HloOpcode::kTuple: 764 // These instructions only define their top-level values. Any other 765 // values flow from their operands. 766 define_top_level_only(); 767 break; 768 case HloOpcode::kRecvDone: 769 // RecvDone aliases its input tuple element {0}, therefore does not 770 // define any values. 771 break; 772 case HloOpcode::kSend: 773 // Send produces a tuple of {aliased operand, U32 context}, therefore 774 // only defines the top-level tuple and the tuple element at {1}. 775 define_value_at(/*index=*/{}); 776 define_value_at(/*index=*/{1}); 777 break; 778 default: 779 define_all_values(); 780 break; 781 } 782 } 783 } 784 785 return Status::OK(); 786 } 787 788 /* static */ 789 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( 790 const HloModule& module, bool ssa_form, bool bitcast_defines_value) { 791 VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); 792 XLA_VLOG_LINES(2, module.ToString()); 793 794 auto dataflow_analysis = WrapUnique( 795 new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); 796 797 TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); 798 dataflow_analysis->Propagate(); 799 800 // Delete all values marked for deletion. 801 dataflow_analysis->DeleteMarkedValues(); 802 803 // Gather and set all non-definition positions of all values. Value deletion 804 // is rare, so just use a vector indexed by Value::Id rather than a map from 805 // Value::Id to positions. There should be very few holes in the vector, and 806 // lookup is faster. 807 std::vector<std::vector<HloPosition>> value_positions( 808 dataflow_analysis->next_value_id_); 809 for (const HloComputation* computation : module.computations()) { 810 for (HloInstruction* instruction : computation->instructions()) { 811 for (const auto& pair : 812 dataflow_analysis->GetInstructionValueSet(instruction)) { 813 const ShapeIndex& index = pair.first; 814 const HloValueSet& value_set = pair.second; 815 for (const HloValue* value : value_set.values()) { 816 if (value->defining_instruction() != instruction) { 817 value_positions[value->id()].push_back( 818 HloPosition{instruction, index}); 819 } 820 } 821 } 822 } 823 } 824 for (auto& pair : dataflow_analysis->values_) { 825 HloValue::Id value_id = pair.first; 826 HloValue& value = pair.second; 827 value.SetPositionsAndComputeUses(value_positions[value_id]); 828 } 829 830 // Construct vector of values. 831 dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); 832 for (auto& pair : dataflow_analysis->values_) { 833 dataflow_analysis->values_vector_.push_back(&pair.second); 834 } 835 std::sort(dataflow_analysis->values_vector_.begin(), 836 dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); 837 838 TF_DCHECK_OK(dataflow_analysis->Verify()); 839 840 XLA_VLOG_LINES(1, dataflow_analysis->ToString()); 841 842 return std::move(dataflow_analysis); 843 } 844 845 Status HloDataflowAnalysis::Verify() const { 846 // Verify each HloValue appears in the value sets that the value's positions() 847 // indicate. 848 for (const HloValue* value : values()) { 849 for (const HloPosition& position : value->positions()) { 850 const HloValueSet& value_set = GetValueSet(position); 851 TF_RET_CHECK(std::find(value_set.values().begin(), 852 value_set.values().end(), 853 value) != value_set.values().end()) 854 << "Value set at position " << position << " does not contain value " 855 << value->ToShortString(); 856 } 857 } 858 859 // For each value in each value set, verify that the value set's position 860 // appears in the value's positions(). 861 for (const auto& computation : module_.computations()) { 862 for (const auto& instruction : computation->instructions()) { 863 for (const auto& pair : GetInstructionValueSet(instruction)) { 864 const ShapeIndex& index = pair.first; 865 const HloValueSet& value_set = pair.second; 866 const HloPosition position{instruction, index}; 867 for (const HloValue* value : value_set.values()) { 868 TF_RET_CHECK(std::find(value->positions().begin(), 869 value->positions().end(), 870 position) != value->positions().end()) 871 << "Value set at position " << position 872 << " unexpectedly contains value " << value->ToShortString(); 873 } 874 } 875 } 876 } 877 878 return Status::OK(); 879 } 880 881 } // namespace xla 882