     16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
     18 #include <algorithm>
     19 #include <queue>
     20 #include <vector>
     22 #include "tensorflow/compiler/xla/map_util.h"
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/status.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/util.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/strcat.h"
     34 #include "tensorflow/core/platform/logging.h"
     36 namespace xla {
     38 using ::tensorflow::strings::StrAppend;
     39 using ::tensorflow::strings::StrCat;
// Constructs an analysis object for `module`. Stores the `ssa_form` and
// `bitcast_defines_value` options and builds a call graph for the module via
// CallGraph::Build. The constructor body is empty; no value sets are created
// or propagated here.
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                         bool bitcast_defines_value)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      call_graph_(CallGraph::Build(&module)) {}
     48 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
     49                                            const ShapeIndex& index) const {
     50   const HloValueSet& value_set = GetValueSet(instruction, index);
     51   if (value_set.values().size() != 1) {
     52     return false;
     53   }
     54   return value_set.GetUniqueValue().defining_instruction() == instruction;
     55 }
// Returns the value defined at `index` in the output of `instruction`
// (read-only overload). CHECK-fails if ValueIsDefinedAt(instruction, index)
// is false.
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index));
  return GetUniqueValueAt(instruction, index);
}
// Returns a mutable reference to the value defined at `index` in the output
// of `instruction`. CHECK-fails if ValueIsDefinedAt(instruction, index) is
// false.
HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) {
  CHECK(ValueIsDefinedAt(instruction, index));
  return GetUniqueValueAt(instruction, index);
}
     69 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
     70                                            const ShapeIndex& index,
     71                                            bool is_phi) {
     72   const int64 value_id = next_value_id_++;
     73   auto emplaced = values_.emplace(
     74       std::piecewise_construct, std::forward_as_tuple(value_id),
     75       std::forward_as_tuple(value_id, instruction, index, is_phi));
     76   CHECK(emplaced.second);
     78   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
     80   return &emplaced.first->second;
     81 }
     83 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
     84   HloValue& value = values_.at(value_id);
     85   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
     87   value_ids_to_delete_.push_back(value_id);
     88 }
     90 void HloDataflowAnalysis::DeleteMarkedValues() {
     91 #ifndef NDEBUG
     92   // Verify that no marked-for-deletion values are in any of the value sets.
     93   tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
     94                                                 value_ids_to_delete_.end());
     95   for (const auto& pair : value_sets_) {
     96     const HloInstruction* instruction = pair.first;
     97     const InstructionValueSet& instruction_value_set = pair.second;
     98     for (const auto& index_value_set : instruction_value_set) {
     99       const HloValueSet& value_set = index_value_set.second;
    100       for (const HloValue* value : value_set.values()) {
    101         DCHECK(!ContainsKey(id_set, value->id()))
    102             << "Value " << value->ToShortString()
    103             << " marked for deletion, but still exists in value set for "
    104                "instruction "
    105             << instruction->name();
    106       }
    107     }
    108   }
    109 #endif
    111   for (HloValue::Id value_id : value_ids_to_delete_) {
    112     values_.erase(value_id);
    113   }
    114   value_ids_to_delete_.clear();
    115 }
    117 string HloDataflowAnalysis::ToString() const {
    118   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
    119   StrAppend(&out, "  Instruction value sets:\n");
    120   for (const HloComputation* computation : module_.computations()) {
    121     for (const HloInstruction* instruction : computation->instructions()) {
    122       StrAppend(&out, "    ", instruction->name(), ":\n");
    123       if (ShapeUtil::IsTuple(instruction->shape())) {
    124         GetInstructionValueSet(instruction)
    125             .ForEachElement([this, &instruction, &out](
    126                                 const ShapeIndex& index,
    127                                 const HloValueSet& value_set) {
    128               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
    129               for (const HloValue* value : value_set.values()) {
    130                 StrAppend(&out, "        ", value->ToShortString(),
    131                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
    132                           "\n");
    133               }
    134             });
    135       } else {
    136         const HloValueSet& top_level_value_set =
    137             GetValueSet(instruction, /*index=*/{});
    138         for (const HloValue* value : top_level_value_set.values()) {
    139           StrAppend(&out, "      ", value->ToShortString(),
    140                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
    141         }
    142       }
    143     }
    144   }
    145   StrAppend(&out, "  HloValues:\n");
    146   for (const HloValue* value : values()) {
    147     StrAppend(&out, value->ToString(/*indent=*/4));
    148   }
    149   return out;
    150 }
// Recomputes the value set of `instruction` as an SSA-style join of `inputs`.
// Only valid when the analysis is in SSA form (CHECK-fails otherwise).
//
// For every shape index of `instruction`, the distinct values reaching that
// index through `inputs` are collected, ignoring values defined by
// `instruction` at that same index:
//   * no value reaches the index: the value set must already be empty;
//   * exactly one value reaches the index: the value set becomes that single
//     value, and a phi previously defined here is marked for deletion;
//   * several values reach the index: a new phi value is created here unless
//     the current value is already a phi defined at this position.
//
// Returns true if any value set of `instruction` was changed.
bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
  CHECK(ssa_form_);
  VLOG(4) << "Phi(" << instruction->name() << ")";
  VLOG(5) << "instruction value set = "
          << GetInstructionValueSet(instruction).ToString();
  for (const InstructionValueSet* input : inputs) {
    VLOG(5) << "input value set = " << input->ToString();
  }
  // Every input must have a shape compatible with the instruction's shape,
  // because the inputs are indexed below with the instruction's shape indices.
  for (const InstructionValueSet* input : inputs) {
    DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of unique value IDs of the inputs.
    // Don't add value ids where the input is equal to the definition.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        if (value->defining_instruction() == instruction &&
            value->defining_index() == index) {
          continue;
        }
        input_value_ids.push_back(value->id());
      }
    }
    // Sort + unique leaves each reaching value id exactly once.
    std::sort(input_value_ids.begin(), input_value_ids.end());
    input_value_ids.erase(
        std::unique(input_value_ids.begin(), input_value_ids.end()),
        input_value_ids.end());

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);
    if (current_value_defined_here) {
      VLOG(5) << "current_value_defined_here: " << current_value->ToString();
      CHECK(current_value->is_phi());
      // NOTE(review): values defined by `instruction` at `index` were already
      // skipped while collecting input_value_ids above, so this lookup looks
      // like it can never find anything -- confirm before removing.
      auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
                          current_value->id());
      if (it != input_value_ids.end()) {
        input_value_ids.erase(it);
      }
    }
    VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow value sets only can go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          MarkValueForDeletion(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      CHECK_GT(input_value_ids.size(), 1);
      // An existing phi defined at this position is reused as is; otherwise
      // the value set is replaced by a newly created phi.
      if (current_value == nullptr ||
          !(current_value->is_phi() && current_value_defined_here)) {
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
        changed = true;
      }
    }
  }
  return changed;
}
// Returns the HloValue stored in `values_` under `value_id` (read-only
// overload). The lookup uses at(), so `value_id` is expected to be present.
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
  return values_.at(value_id);
}
// Returns a mutable reference to the HloValue stored in `values_` under
// `value_id`. The lookup uses at(), so `value_id` is expected to be present.
HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
  return values_.at(value_id);
}
// Returns the value set at `index` within the InstructionValueSet of
// `instruction` (read-only overload).
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  return GetInstructionValueSet(instruction).element(index);
}
// Returns a mutable reference to the value set at `index` within the
// InstructionValueSet of `instruction`.
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
                                              const ShapeIndex& index) {
  return *GetInstructionValueSet(instruction).mutable_element(index);
}
// Convenience overload: returns the value set at `position`, i.e. at
// `position.index` in the output of `position.instruction` (read-only).
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloPosition& position) const {
  return GetValueSet(position.instruction, position.index);
}
// Convenience overload: returns a mutable reference to the value set at
// `position`, i.e. at `position.index` in the output of
// `position.instruction`.
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
  return GetValueSet(position.instruction, position.index);
}
    277 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
    278   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
    279   const InstructionValueSet& operand_set =
    280       GetInstructionValueSet(bitcast->operand(0));
    281   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
    282   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
    283     bitcast_set = operand_set;
    284     return true;
    285   }
    286   return false;
    287 }
    289 bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) {
    290   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
    291   if (!slice->IsInPlaceSlice()) {
    292     return false;
    293   }
    294   // If this slice is lowered to an in-place version, then it forwards the
    295   // operand value to the output.
    296   const InstructionValueSet& operand_set =
    297       GetInstructionValueSet(slice->operand(0));
    298   InstructionValueSet& slice_set = GetInstructionValueSet(slice);
    299   if (operand_set != slice_set) {
    300     slice_set = operand_set;
    301     return true;
    302   }
    303   return false;
    304 }
    306 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
    307   CHECK_EQ(send->opcode(), HloOpcode::kSend);
    308   bool changed = false;
    309   // Send forwards the operand value to the output tuple at {0}.
    310   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
    311     const ShapeIndex& operand_index = pair.first;
    312     const HloValueSet& operand_value_set = pair.second;
    314     ShapeIndex index = {0};
    315     for (int64 i : operand_index) {
    316       index.push_back(i);
    317     }
    319     HloValueSet& value_set = GetValueSet(send, index);
    320     if (value_set != operand_value_set) {
    321       value_set = operand_value_set;
    322       changed = true;
    323     }
    324   }
    325   return changed;
    326 }
    328 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
    329   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
    330   bool changed = false;
    331   // RecvDone forwards the operand value at {0} to the output.
    332   for (auto& pair : GetInstructionValueSet(recv_done)) {
    333     ShapeIndex& index = pair.first;
    334     HloValueSet& value_set = pair.second;
    336     ShapeIndex operand_index = {0};
    337     for (int64 i : index) {
    338       operand_index.push_back(i);
    339     }
    341     const HloValueSet& operand_value_set =
    342         GetValueSet(recv_done->operand(0), operand_index);
    343     if (value_set != operand_value_set) {
    344       value_set = operand_value_set;
    345       changed = true;
    346     }
    347   }
    348   return changed;
    349 }
    351 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
    352   CHECK_EQ(call->opcode(), HloOpcode::kCall);
    353   InstructionValueSet& value_set = GetInstructionValueSet(call);
    354   InstructionValueSet& root_value_set =
    355       GetInstructionValueSet(call->to_apply()->root_instruction());
    356   if (value_set != root_value_set) {
    357     value_set = root_value_set;
    358     return true;
    359   }
    360   return false;
    361 }
    363 bool HloDataflowAnalysis::UpdateConditionalValueSet(
    364     HloInstruction* conditional) {
    365   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
    366   std::vector<const InstructionValueSet*> inputs = {
    367       &GetInstructionValueSet(
    368           conditional->true_computation()->root_instruction()),
    369       &GetInstructionValueSet(
    370           conditional->false_computation()->root_instruction())};
    371   // A phi-node is not defined for a kConditional instruction even though it
    372   // represents a join point. This is because the current approach is to define
    373   // a phi-node only for kWhile to account for the dataflow through back-edges
    374   // and deal with the ambiguity in other cases.
    375   return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
    376 }
    378 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
    379   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    380   bool changed = false;
    381   for (auto& pair : GetInstructionValueSet(copy)) {
    382     const ShapeIndex& index = pair.first;
    383     if (index.empty()) {
    384       // kCopy shallow copies and thus defines the top-level value so nothing to
    385       // update.
    386       continue;
    387     }
    389     HloValueSet& value_set = pair.second;
    390     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
    391     if (value_set != operand_value_set) {
    392       value_set = operand_value_set;
    393       changed = true;
    394     }
    395   }
    396   return changed;
    397 }
    399 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
    400   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
    401   bool changed = false;
    402   // The GetTupleElement instruction forwards the values from the specified
    403   // tuple element.
    404   for (auto& pair : GetInstructionValueSet(gte)) {
    405     const ShapeIndex& index = pair.first;
    406     HloValueSet& value_set = pair.second;
    408     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
    409     // with the tuple element number prefixed.
    410     ShapeIndex operand_index = {gte->tuple_index()};
    411     for (int64 i : index) {
    412       operand_index.push_back(i);
    413     }
    415     HloValueSet& operand_value_set =
    416         GetValueSet(gte->operand(0), operand_index);
    417     if (value_set != operand_value_set) {
    418       value_set = operand_value_set;
    419       changed = true;
    420     }
    421   }
    422   return changed;
    423 }
    425 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
    426   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
    427   const CallGraphNode& call_graph_node =
    428       call_graph_->GetNode(parameter->parent());
    430   // Subcomputations called in a parallel context (eg, map) do not have dataflow
    431   // from the caller operands.
    432   if (call_graph_node.context() == CallContext::kParallel ||
    433       call_graph_node.caller_callsites().empty()) {
    434     return false;
    435   }
    436   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
    438   std::vector<const InstructionValueSet*> inputs;
    439   bool need_phi = false;
    440   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    441     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
    442       // The operand values of a call instruction are forwarded to the
    443       // respective parameter instruction of the subcomputation.
    444       inputs.push_back(&GetInstructionValueSet(
    445           callsite.instruction()->operand(parameter->parameter_number())));
    446     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    447       // In a while instruction, the while operand (ie, the init value) and the
    448       // backedge are dataflow inputs to the parameter instruction. This is the
    449       // case for parameters of both the body and condition computations.
    450       CHECK_EQ(parameter->parameter_number(), 0);
    451       inputs.push_back(
    452           &GetInstructionValueSet(callsite.instruction()->operand(0)));
    453       // If the parameter *is* the root, then don't consider it's current state
    454       // (InstructionValueSet) as we are recomputing its current
    455       // state. Otherwise, the parameter state would never be updated.
    456       if (parameter !=
    457           callsite.instruction()->while_body()->root_instruction()) {
    458         inputs.push_back(&GetInstructionValueSet(
    459             callsite.instruction()->while_body()->root_instruction()));
    460       }
    461       need_phi = true;
    462     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
    463       CHECK_EQ(parameter->parameter_number(), 0);
    464       auto conditional = callsite.instruction();
    465       // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is
    466       // the argument to the true computation and operand 2 is the argument to
    467       // the false computation.
    468       //
    469       // If the parameter belongs to conditional's true computation, then
    470       // operand 1 is forwarded to this parameter instruction. If the parameter
    471       // belongs to conditional's false computation, then operand 2 is forwarded
    472       // to this parameter instruction.
    473       if (parameter->parent() == conditional->true_computation()) {
    474         inputs.push_back(&GetInstructionValueSet(conditional->operand(1)));
    475       } else {
    476         CHECK_EQ(parameter->parent(), conditional->false_computation());
    477         inputs.push_back(&GetInstructionValueSet(conditional->operand(2)));
    478       }
    479       need_phi = true;
    480     } else {
    481       LOG(FATAL) << "CallContext::kSequential computations should only be "
    482                     "called from call, while, or conditional instructions";
    483     }
    484   }
    486   if (ssa_form_ && need_phi) {
    487     return Phi(parameter, inputs);
    488   } else {
    489     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
    490   }
    491 }
    493 bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) {
    494   CHECK_EQ(select->opcode(), HloOpcode::kSelect);
    495   // A phi value is not defined at a kSelect instruction because kSelect does
    496   // not create a new value. Rather it forwards a value from its operands. This
    497   // contrasts with kWhile instruction (which does define a phi value) which has
    498   // in-place update semantics.
    499   bool changed = false;
    500   for (auto& pair : GetInstructionValueSet(select)) {
    501     const ShapeIndex& index = pair.first;
    502     if (index.empty()) {
    503       // kSelect copies (not forwards) the top-level value.
    504       continue;
    505     }
    506     HloValueSet& value_set = pair.second;
    507     changed |=
    508         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
    509                                  &GetValueSet(select->operand(2), index)});
    510   }
    511   return changed;
    512 }
    514 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
    515   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    516   bool changed = false;
    517   for (int64 i = 0; i < tuple->operands().size(); ++i) {
    518     // Copy the value set(s) of each operand into the respective position in the
    519     // kTuple instruction's value sets.
    520     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
    521       const ShapeIndex& operand_index = pair.first;
    522       HloValueSet& operand_value_set = pair.second;
    524       ShapeIndex index = {i};
    525       for (int64 op_index : operand_index) {
    526         index.push_back(op_index);
    527       }
    528       HloValueSet& value_set = GetValueSet(tuple, index);
    530       if (value_set != operand_value_set) {
    531         value_set = operand_value_set;
    532         changed = true;
    533       }
    534     }
    535   }
    536   return changed;
    537 }
    539 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
    540   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
    541   std::vector<const InstructionValueSet*> inputs = {
    542       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
    543       &GetInstructionValueSet(xla_while->operand(0))};
    544   if (ssa_form_) {
    545     return Phi(xla_while, inputs);
    546   } else {
    547     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
    548   }
    549 }
    551 bool HloDataflowAnalysis::UpdateInstructionValueSet(
    552     HloInstruction* instruction) {
    553   // Recompute from operands.
    554   switch (instruction->opcode()) {
    555     case HloOpcode::kBitcast:
    556       return UpdateBitcastValueSet(instruction);
    557     case HloOpcode::kSlice:
    558       return UpdateSliceValueSet(instruction);
    559     case HloOpcode::kCopy:
    560       return UpdateCopyValueSet(instruction);
    561     case HloOpcode::kGetTupleElement:
    562       return UpdateGetTupleElementValueSet(instruction);
    563     case HloOpcode::kSelect:
    564       return UpdateSelectValueSet(instruction);
    565     case HloOpcode::kTuple:
    566       return UpdateTupleValueSet(instruction);
    567     case HloOpcode::kParameter:
    568       return UpdateParameterValueSet(instruction);
    569     case HloOpcode::kCall:
    570       return UpdateCallValueSet(instruction);
    571     case HloOpcode::kWhile:
    572       return UpdateWhileValueSet(instruction);
    573     case HloOpcode::kSend:
    574       return UpdateSendValueSet(instruction);
    575     case HloOpcode::kRecvDone:
    576       return UpdateRecvDoneValueSet(instruction);
    577     case HloOpcode::kConditional:
    578       return UpdateConditionalValueSet(instruction);
    579     default:
    580       // Instruction does not forward HloValues (it defines all values in its
    581       // output). No update is necessary.
    582       return false;
    583   }
    584 }
    586 void HloDataflowAnalysis::Propagate() {
    587   std::queue<HloInstruction*> worklist;
    588   tensorflow::gtl::FlatSet<HloInstruction*> workset;
    589   auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
    590     if (workset.insert(instruction).second) {
    591       worklist.push(instruction);
    592     }
    593   };
    595   for (HloComputation* computation : module_.computations()) {
    596     for (HloInstruction* instruction : computation->instructions()) {
    597       add_to_worklist(instruction);
    598     }
    599   }
    601   while (!worklist.empty()) {
    602     HloInstruction* instruction = worklist.front();
    603     worklist.pop();
    604     workset.erase(workset.find(instruction));
    606     VLOG(3) << "Worklist top: " << instruction->name();
    607     VLOG(3) << ToString();
    609     if (!UpdateInstructionValueSet(instruction)) {
    610       // No change to the instruction's value set.
    611       VLOG(4) << "No change.";
    612       continue;
    613     }
    615     VLOG(4) << "New value set for " << instruction->name() << ": "
    616             << GetInstructionValueSet(instruction);
    618     // Instruction value was updated. Add users to work list if we haven't
    619     // already.
    620     for (HloInstruction* user : instruction->users()) {
    621       add_to_worklist(user);
    623       // If user sequentially calls a computation, then the respective
    624       // parameter(s) of the computation need to be updated.
    625       if (user->opcode() == HloOpcode::kConditional) {
    626         // If operand 0 is the use of instruction, then no parameters need to be
    627         // updated, since that is the predicate of the conditional.
    628         // If operand 1 is the use of instruction, then the true_computation's
    629         // parameter need to be updated.
    630         // If operand 2 is the use of instruction, then the false_computation's
    631         // parameter need to be updated.
    632         //
    633         // Note that the same instruction can be used in both operand 1 and
    634         // operand 2.
    635         if (user->operand(1) == instruction) {
    636           add_to_worklist(user->true_computation()->parameter_instruction(0));
    637         }
    638         if (user->operand(2) == instruction) {
    639           add_to_worklist(user->false_computation()->parameter_instruction(0));
    640         }
    641       } else {
    642         for (HloComputation* called_computation : user->called_computations()) {
    643           const CallGraphNode& call_graph_node =
    644               call_graph_->GetNode(called_computation);
    645           if (call_graph_node.context() == CallContext::kSequential) {
    646             for (int64 operand_number : user->OperandIndices(instruction)) {
    647               add_to_worklist(
    648                   called_computation->parameter_instruction(operand_number));
    649             }
    650           }
    651         }
    652       }
    653     }
    655     // If instruction is a root instruction, then propagate out to any calling
    656     // instruction and across any while backedge.
    657     if (instruction == instruction->parent()->root_instruction()) {
    658       const CallGraphNode& call_graph_node =
    659           call_graph_->GetNode(instruction->parent());
    660       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    661         if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
    662             (callsite.instruction()->opcode() == HloOpcode::kConditional)) {
    663           add_to_worklist(callsite.instruction());
    664         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    665           // Add the while itself, and the body and condition parameters.
    666           add_to_worklist(callsite.instruction());
    667           add_to_worklist(
    668               callsite.instruction()->while_body()->parameter_instruction(0));
    669           add_to_worklist(
    670               callsite.instruction()->while_condition()->parameter_instruction(
    671                   0));
    672         }
    673       }
    674     }
    675   }
    676 }
// Returns the InstructionValueSet stored in `value_sets_` for `instruction`
// (read-only overload). The lookup uses at(), so an entry for `instruction`
// is expected to exist already.
const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) const {
  return value_sets_.at(instruction);
}
// Returns a mutable reference to the InstructionValueSet stored in
// `value_sets_` for `instruction`. The lookup uses at(), so an entry for
// `instruction` is expected to exist already.
InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) {
  return value_sets_.at(instruction);
}
    688 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
    689   for (const HloComputation* computation : module_.computations()) {
    690     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
    691     for (HloInstruction* instruction : computation->instructions()) {
    692       // Create an empty shape tree.
    693       value_sets_.emplace(std::piecewise_construct,
    694                           std::forward_as_tuple(instruction),
    695                           std::forward_as_tuple(instruction->shape()));
    697       // Lambda to set the value set to define all values in the output of the
    698       // instruction.
    699       auto define_all_values = [this, &instruction](bool is_phi = false) {
    700         for (auto& pair : GetInstructionValueSet(instruction)) {
    701           const ShapeIndex& index = pair.first;
    702           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    703           GetValueSet(instruction, index).AddValue(value);
    704         }
    705       };
    707       // Lambda to set the value set to define only the top-level buffer in the
    708       // output of the instruction. Any other values flow from the operands of
    709       // the instruction (or from cross-computation dataflow).
    710       auto define_top_level_only = [this, &instruction]() {
    711         HloValue* value =
    712             NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false);
    713         GetValueSet(instruction, /*index=*/{}).AddValue(value);
    714       };
    716       // Lambda to set the value set at the given index of the output.
    717       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
    718         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    719         GetValueSet(instruction, index).AddValue(value);
    720       };
    722       switch (instruction->opcode()) {
    723         case HloOpcode::kBitcast:
    724           if (bitcast_defines_value_) {
    725             define_all_values();
    726           }
    727           break;
    728         case HloOpcode::kSlice:
    729           if (!instruction->IsInPlaceSlice()) {
    730             define_all_values();
    731           }
    732           break;
    733         case HloOpcode::kWhile:
    734         case HloOpcode::kCall:
    735         case HloOpcode::kConditional:
    736         case HloOpcode::kGetTupleElement:
    737           // These instructions define no values. The values in their output
    738           // flow from their operands or from cross computation dataflow.
    739           break;
    740         case HloOpcode::kParameter:
    741           if (call_graph_node.context() == CallContext::kBoth) {
    742             // We do not support a subcomputation that is called from both a
    743             // parallel and sequential context. In this case, the parameter
    744             // would both define a value and propagate a value from its
    745             // caller. This limitation is not really a problem because the call
    746             // graph is typically flattened.
    747             return Unimplemented(
    748                 "Computation %s is called in both a parallel (eg, kMap) and "
    749                 "sequential (eg, kCall) context",
    750                 computation->name().c_str());
    751           }
    752           if (call_graph_node.caller_callsites().empty() ||
    753               call_graph_node.context() == CallContext::kParallel) {
    754             // Parameters of computations called in a parallel context (eg, map
    755             // and reduce) as well as parameters of dead computations define all
    756             // values in their output. Otherwise the values of the parameter
    757             // come from the caller (eg, operands to the kCall instruction).
    758             define_all_values();
    759           }
    760           break;
    761         case HloOpcode::kCopy:
    762         case HloOpcode::kSelect:
    763         case HloOpcode::kTuple:
    764           // These instructions only define their top-level values. Any other
    765           // values flow from their operands.
    766           define_top_level_only();
    767           break;
    768         case HloOpcode::kRecvDone:
    769           // RecvDone aliases its input tuple element {0}, therefore does not
    770           // define any values.
    771           break;
    772         case HloOpcode::kSend:
    773           // Send produces a tuple of {aliased operand, U32 context}, therefore
    774           // only defines the top-level tuple and the tuple element at {1}.
    775           define_value_at(/*index=*/{});
    776           define_value_at(/*index=*/{1});
    777           break;
    778         default:
    779           define_all_values();
    780           break;
    781       }
    782     }
    783   }
    785   return Status::OK();
    786 }
    788 /* static */
    789 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    790     const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
    791   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
    792   XLA_VLOG_LINES(2, module.ToString());
    794   auto dataflow_analysis = WrapUnique(
    795       new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
    797   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
    798   dataflow_analysis->Propagate();
    800   // Delete all values marked for deletion.
    801   dataflow_analysis->DeleteMarkedValues();
    803   // Gather and set all non-definition positions of all values. Value deletion
    804   // is rare, so just use a vector indexed by Value::Id rather than a map from
    805   // Value::Id to positions. There should be very few holes in the vector, and
    806   // lookup is faster.
    807   std::vector<std::vector<HloPosition>> value_positions(
    808       dataflow_analysis->next_value_id_);
    809   for (const HloComputation* computation : module.computations()) {
    810     for (HloInstruction* instruction : computation->instructions()) {
    811       for (const auto& pair :
    812            dataflow_analysis->GetInstructionValueSet(instruction)) {
    813         const ShapeIndex& index = pair.first;
    814         const HloValueSet& value_set = pair.second;
    815         for (const HloValue* value : value_set.values()) {
    816           if (value->defining_instruction() != instruction) {
    817             value_positions[value->id()].push_back(
    818                 HloPosition{instruction, index});
    819           }
    820         }
    821       }
    822     }
    823   }
    824   for (auto& pair : dataflow_analysis->values_) {
    825     HloValue::Id value_id = pair.first;
    826     HloValue& value = pair.second;
    827     value.SetPositionsAndComputeUses(value_positions[value_id]);
    828   }
    830   // Construct vector of values.
    831   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
    832   for (auto& pair : dataflow_analysis->values_) {
    833     dataflow_analysis->values_vector_.push_back(&pair.second);
    834   }
    835   std::sort(dataflow_analysis->values_vector_.begin(),
    836             dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
    838   TF_DCHECK_OK(dataflow_analysis->Verify());
    840   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
    842   return std::move(dataflow_analysis);
    843 }
    845 Status HloDataflowAnalysis::Verify() const {
    846   // Verify each HloValue appears in the value sets that the value's positions()
    847   // indicate.
    848   for (const HloValue* value : values()) {
    849     for (const HloPosition& position : value->positions()) {
    850       const HloValueSet& value_set = GetValueSet(position);
    851       TF_RET_CHECK(std::find(value_set.values().begin(),
    852                              value_set.values().end(),
    853                              value) != value_set.values().end())
    854           << "Value set at position " << position << " does not contain value "
    855           << value->ToShortString();
    856     }
    857   }
    859   // For each value in each value set, verify that the value set's position
    860   // appears in the value's positions().
    861   for (const auto& computation : module_.computations()) {
    862     for (const auto& instruction : computation->instructions()) {
    863       for (const auto& pair : GetInstructionValueSet(instruction)) {
    864         const ShapeIndex& index = pair.first;
    865         const HloValueSet& value_set = pair.second;
    866         const HloPosition position{instruction, index};
    867         for (const HloValue* value : value_set.values()) {
    868           TF_RET_CHECK(std::find(value->positions().begin(),
    869                                  value->positions().end(),
    870                                  position) != value->positions().end())
    871               << "Value set at position " << position
    872               << " unexpectedly contains value " << value->ToShortString();
    873         }
    874       }
    875     }
    876   }
    878   return Status::OK();
    879 }
    881 }  // namespace xla