Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
     17 
     18 #include <algorithm>
     19 #include <queue>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/map_util.h"
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/status.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/util.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/strcat.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 
     36 namespace xla {
     37 
     38 using ::tensorflow::strings::StrAppend;
     39 using ::tensorflow::strings::StrCat;
     40 
// Stores the module reference and the two configuration flags, and builds the
// call graph for `module`. No value sets are created by the constructor.
//
// `ssa_form` is required by Phi() and selects phi insertion in the while and
// parameter updaters; `bitcast_defines_value` controls whether
// UpdateBitcastValueSet forwards the operand's value set to the bitcast.
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                         bool bitcast_defines_value)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      call_graph_(CallGraph::Build(&module)) {}
     47 
     48 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
     49                                            const ShapeIndex& index) const {
     50   const HloValueSet& value_set = GetValueSet(instruction, index);
     51   if (value_set.values().size() != 1) {
     52     return false;
     53   }
     54   return value_set.GetUniqueValue().defining_instruction() == instruction;
     55 }
     56 
     57 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
     58     const HloInstruction* instruction, const ShapeIndex& index) const {
     59   CHECK(ValueIsDefinedAt(instruction, index));
     60   return GetUniqueValueAt(instruction, index);
     61 }
     62 
     63 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
     64     const HloInstruction* instruction, const ShapeIndex& index) {
     65   CHECK(ValueIsDefinedAt(instruction, index));
     66   return GetUniqueValueAt(instruction, index);
     67 }
     68 
     69 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
     70                                            const ShapeIndex& index,
     71                                            bool is_phi) {
     72   const int64 value_id = next_value_id_++;
     73   auto emplaced = values_.emplace(
     74       std::piecewise_construct, std::forward_as_tuple(value_id),
     75       std::forward_as_tuple(value_id, instruction, index, is_phi));
     76   CHECK(emplaced.second);
     77 
     78   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
     79 
     80   return &emplaced.first->second;
     81 }
     82 
     83 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
     84   HloValue& value = values_.at(value_id);
     85   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
     86 
     87   value_ids_to_delete_.push_back(value_id);
     88 }
     89 
     90 void HloDataflowAnalysis::DeleteMarkedValues() {
     91 #ifndef NDEBUG
     92   // Verify that no marked-for-deletion values are in any of the value sets.
     93   tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
     94                                                 value_ids_to_delete_.end());
     95   for (const auto& pair : value_sets_) {
     96     const HloInstruction* instruction = pair.first;
     97     const InstructionValueSet& instruction_value_set = pair.second;
     98     for (const auto& index_value_set : instruction_value_set) {
     99       const HloValueSet& value_set = index_value_set.second;
    100       for (const HloValue* value : value_set.values()) {
    101         DCHECK(!ContainsKey(id_set, value->id()))
    102             << "Value " << value->ToShortString()
    103             << " marked for deletion, but still exists in value set for "
    104                "instruction "
    105             << instruction->name();
    106       }
    107     }
    108   }
    109 #endif
    110 
    111   for (HloValue::Id value_id : value_ids_to_delete_) {
    112     values_.erase(value_id);
    113   }
    114   value_ids_to_delete_.clear();
    115 }
    116 
    117 string HloDataflowAnalysis::ToString() const {
    118   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
    119   StrAppend(&out, "  Instruction value sets:\n");
    120   for (const HloComputation* computation : module_.computations()) {
    121     for (const HloInstruction* instruction : computation->instructions()) {
    122       StrAppend(&out, "    ", instruction->name(), ":\n");
    123       if (ShapeUtil::IsTuple(instruction->shape())) {
    124         GetInstructionValueSet(instruction)
    125             .ForEachElement([this, &instruction, &out](
    126                                 const ShapeIndex& index,
    127                                 const HloValueSet& value_set) {
    128               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
    129               for (const HloValue* value : value_set.values()) {
    130                 StrAppend(&out, "        ", value->ToShortString(),
    131                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
    132                           "\n");
    133               }
    134             });
    135       } else {
    136         const HloValueSet& top_level_value_set =
    137             GetValueSet(instruction, /*index=*/{});
    138         for (const HloValue* value : top_level_value_set.values()) {
    139           StrAppend(&out, "      ", value->ToShortString(),
    140                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
    141         }
    142       }
    143     }
    144   }
    145   StrAppend(&out, "  HloValues:\n");
    146   for (const HloValue* value : values()) {
    147     StrAppend(&out, value->ToString(/*indent=*/4));
    148   }
    149   return out;
    150 }
    151 
    152 bool HloDataflowAnalysis::Phi(
    153     HloInstruction* instruction,
    154     tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
    155   CHECK(ssa_form_);
    156   VLOG(4) << "Phi(" << instruction->name() << ")";
    157   VLOG(5) << "instruction value set = "
    158           << GetInstructionValueSet(instruction).ToString();
    159   for (const InstructionValueSet* input : inputs) {
    160     VLOG(5) << "input value set = " << input->ToString();
    161   }
    162   for (const InstructionValueSet* input : inputs) {
    163     DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    164   }
    165 
    166   bool changed = false;
    167   for (auto& pair : GetInstructionValueSet(instruction)) {
    168     const ShapeIndex& index = pair.first;
    169     HloValueSet& value_set = pair.second;
    170 
    171     // Positions with phi values should never have more than one value in the
    172     // value set.
    173     CHECK_LE(value_set.values().size(), 1);
    174     const HloValue* current_value =
    175         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
    176 
    177     // Construct a vector of unique value IDs of the inputs.
    178     // Don't add value ids where the input is equal to the definition.
    179     std::vector<HloValue::Id> input_value_ids;
    180     for (const InstructionValueSet* input : inputs) {
    181       for (const HloValue* value : input->element(index).values()) {
    182         if (value->defining_instruction() == instruction &&
    183             value->defining_index() == index) {
    184           continue;
    185         }
    186         input_value_ids.push_back(value->id());
    187       }
    188     }
    189     std::sort(input_value_ids.begin(), input_value_ids.end());
    190     input_value_ids.erase(
    191         std::unique(input_value_ids.begin(), input_value_ids.end()),
    192         input_value_ids.end());
    193 
    194     // Remove the existing phi value (if it exists). The phi can be its own
    195     // input, for example, in while body parameters where the body passes
    196     // through the parameter value.
    197     bool current_value_defined_here =
    198         (current_value != nullptr &&
    199          current_value->defining_instruction() == instruction &&
    200          current_value->defining_index() == index);
    201     if (current_value_defined_here) {
    202       VLOG(5) << "current_value_defined_here: " << current_value->ToString();
    203       CHECK(current_value->is_phi());
    204       auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
    205                           current_value->id());
    206       if (it != input_value_ids.end()) {
    207         input_value_ids.erase(it);
    208       }
    209     }
    210     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    211     if (input_value_ids.empty()) {
    212       // A value set which has at least one element should never have its value
    213       // set reduced to zero elements. During dataflow value sets only can go
    214       // from empty to non-empty, not the reverse.
    215       CHECK_EQ(value_set.values().size(), 0)
    216           << "Instruction " << instruction->name() << " at index " << index
    217           << " previously had non-empty value set. Value set: " << value_set;
    218     } else if (input_value_ids.size() == 1) {
    219       // Only a single value reaches this point. There should be no phi, and
    220       // this value set should contain this single value.
    221       const HloValue& new_value = GetValue(input_value_ids[0]);
    222       if (current_value == nullptr) {
    223         value_set.Clear();
    224         value_set.AddValue(&new_value);
    225         changed = true;
    226       } else if (current_value != &new_value) {
    227         if (current_value_defined_here) {
    228           // Remove the existing phi.
    229           MarkValueForDeletion(current_value->id());
    230         }
    231         value_set.Clear();
    232         value_set.AddValue(&new_value);
    233         changed = true;
    234       }
    235     } else {
    236       // Multiple distinct values reach this point. A phi value is
    237       // necessary.
    238       CHECK_GT(input_value_ids.size(), 1);
    239       if (current_value == nullptr ||
    240           !(current_value->is_phi() && current_value_defined_here)) {
    241         value_set.Clear();
    242         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
    243         changed = true;
    244       }
    245     }
    246   }
    247   return changed;
    248 }
    249 
    250 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
    251   return values_.at(value_id);
    252 }
    253 
    254 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
    255   return values_.at(value_id);
    256 }
    257 
    258 const HloValueSet& HloDataflowAnalysis::GetValueSet(
    259     const HloInstruction* instruction, const ShapeIndex& index) const {
    260   return GetInstructionValueSet(instruction).element(index);
    261 }
    262 
    263 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
    264                                               const ShapeIndex& index) {
    265   return *GetInstructionValueSet(instruction).mutable_element(index);
    266 }
    267 
    268 const HloValueSet& HloDataflowAnalysis::GetValueSet(
    269     const HloPosition& position) const {
    270   return GetValueSet(position.instruction, position.index);
    271 }
    272 
    273 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
    274   return GetValueSet(position.instruction, position.index);
    275 }
    276 
    277 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
    278   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
    279   const InstructionValueSet& operand_set =
    280       GetInstructionValueSet(bitcast->operand(0));
    281   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
    282   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
    283     bitcast_set = operand_set;
    284     return true;
    285   }
    286   return false;
    287 }
    288 
    289 bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) {
    290   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
    291   if (!slice->IsInPlaceSlice()) {
    292     return false;
    293   }
    294   // If this slice is lowered to an in-place version, then it forwards the
    295   // operand value to the output.
    296   const InstructionValueSet& operand_set =
    297       GetInstructionValueSet(slice->operand(0));
    298   InstructionValueSet& slice_set = GetInstructionValueSet(slice);
    299   if (operand_set != slice_set) {
    300     slice_set = operand_set;
    301     return true;
    302   }
    303   return false;
    304 }
    305 
    306 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
    307   CHECK_EQ(send->opcode(), HloOpcode::kSend);
    308   bool changed = false;
    309   // Send forwards the operand value to the output tuple at {0}.
    310   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
    311     const ShapeIndex& operand_index = pair.first;
    312     const HloValueSet& operand_value_set = pair.second;
    313 
    314     ShapeIndex index = {0};
    315     for (int64 i : operand_index) {
    316       index.push_back(i);
    317     }
    318 
    319     HloValueSet& value_set = GetValueSet(send, index);
    320     if (value_set != operand_value_set) {
    321       value_set = operand_value_set;
    322       changed = true;
    323     }
    324   }
    325   return changed;
    326 }
    327 
    328 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
    329   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
    330   bool changed = false;
    331   // RecvDone forwards the operand value at {0} to the output.
    332   for (auto& pair : GetInstructionValueSet(recv_done)) {
    333     ShapeIndex& index = pair.first;
    334     HloValueSet& value_set = pair.second;
    335 
    336     ShapeIndex operand_index = {0};
    337     for (int64 i : index) {
    338       operand_index.push_back(i);
    339     }
    340 
    341     const HloValueSet& operand_value_set =
    342         GetValueSet(recv_done->operand(0), operand_index);
    343     if (value_set != operand_value_set) {
    344       value_set = operand_value_set;
    345       changed = true;
    346     }
    347   }
    348   return changed;
    349 }
    350 
    351 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
    352   CHECK_EQ(call->opcode(), HloOpcode::kCall);
    353   InstructionValueSet& value_set = GetInstructionValueSet(call);
    354   InstructionValueSet& root_value_set =
    355       GetInstructionValueSet(call->to_apply()->root_instruction());
    356   if (value_set != root_value_set) {
    357     value_set = root_value_set;
    358     return true;
    359   }
    360   return false;
    361 }
    362 
    363 bool HloDataflowAnalysis::UpdateConditionalValueSet(
    364     HloInstruction* conditional) {
    365   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
    366   std::vector<const InstructionValueSet*> inputs = {
    367       &GetInstructionValueSet(
    368           conditional->true_computation()->root_instruction()),
    369       &GetInstructionValueSet(
    370           conditional->false_computation()->root_instruction())};
    371   // A phi-node is not defined for a kConditional instruction even though it
    372   // represents a join point. This is because the current approach is to define
    373   // a phi-node only for kWhile to account for the dataflow through back-edges
    374   // and deal with the ambiguity in other cases.
    375   return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
    376 }
    377 
    378 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
    379   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    380   bool changed = false;
    381   for (auto& pair : GetInstructionValueSet(copy)) {
    382     const ShapeIndex& index = pair.first;
    383     if (index.empty()) {
    384       // kCopy shallow copies and thus defines the top-level value so nothing to
    385       // update.
    386       continue;
    387     }
    388 
    389     HloValueSet& value_set = pair.second;
    390     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
    391     if (value_set != operand_value_set) {
    392       value_set = operand_value_set;
    393       changed = true;
    394     }
    395   }
    396   return changed;
    397 }
    398 
    399 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
    400   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
    401   bool changed = false;
    402   // The GetTupleElement instruction forwards the values from the specified
    403   // tuple element.
    404   for (auto& pair : GetInstructionValueSet(gte)) {
    405     const ShapeIndex& index = pair.first;
    406     HloValueSet& value_set = pair.second;
    407 
    408     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
    409     // with the tuple element number prefixed.
    410     ShapeIndex operand_index = {gte->tuple_index()};
    411     for (int64 i : index) {
    412       operand_index.push_back(i);
    413     }
    414 
    415     HloValueSet& operand_value_set =
    416         GetValueSet(gte->operand(0), operand_index);
    417     if (value_set != operand_value_set) {
    418       value_set = operand_value_set;
    419       changed = true;
    420     }
    421   }
    422   return changed;
    423 }
    424 
    425 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
    426   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
    427   const CallGraphNode& call_graph_node =
    428       call_graph_->GetNode(parameter->parent());
    429 
    430   // Subcomputations called in a parallel context (eg, map) do not have dataflow
    431   // from the caller operands.
    432   if (call_graph_node.context() == CallContext::kParallel ||
    433       call_graph_node.caller_callsites().empty()) {
    434     return false;
    435   }
    436   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
    437 
    438   std::vector<const InstructionValueSet*> inputs;
    439   bool need_phi = false;
    440   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    441     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
    442       // The operand values of a call instruction are forwarded to the
    443       // respective parameter instruction of the subcomputation.
    444       inputs.push_back(&GetInstructionValueSet(
    445           callsite.instruction()->operand(parameter->parameter_number())));
    446     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    447       // In a while instruction, the while operand (ie, the init value) and the
    448       // backedge are dataflow inputs to the parameter instruction. This is the
    449       // case for parameters of both the body and condition computations.
    450       CHECK_EQ(parameter->parameter_number(), 0);
    451       inputs.push_back(
    452           &GetInstructionValueSet(callsite.instruction()->operand(0)));
    453       // If the parameter *is* the root, then don't consider it's current state
    454       // (InstructionValueSet) as we are recomputing its current
    455       // state. Otherwise, the parameter state would never be updated.
    456       if (parameter !=
    457           callsite.instruction()->while_body()->root_instruction()) {
    458         inputs.push_back(&GetInstructionValueSet(
    459             callsite.instruction()->while_body()->root_instruction()));
    460       }
    461       need_phi = true;
    462     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
    463       CHECK_EQ(parameter->parameter_number(), 0);
    464       auto conditional = callsite.instruction();
    465       // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is
    466       // the argument to the true computation and operand 2 is the argument to
    467       // the false computation.
    468       //
    469       // If the parameter belongs to conditional's true computation, then
    470       // operand 1 is forwarded to this parameter instruction. If the parameter
    471       // belongs to conditional's false computation, then operand 2 is forwarded
    472       // to this parameter instruction.
    473       if (parameter->parent() == conditional->true_computation()) {
    474         inputs.push_back(&GetInstructionValueSet(conditional->operand(1)));
    475       } else {
    476         CHECK_EQ(parameter->parent(), conditional->false_computation());
    477         inputs.push_back(&GetInstructionValueSet(conditional->operand(2)));
    478       }
    479       need_phi = true;
    480     } else {
    481       LOG(FATAL) << "CallContext::kSequential computations should only be "
    482                     "called from call, while, or conditional instructions";
    483     }
    484   }
    485 
    486   if (ssa_form_ && need_phi) {
    487     return Phi(parameter, inputs);
    488   } else {
    489     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
    490   }
    491 }
    492 
    493 bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) {
    494   CHECK_EQ(select->opcode(), HloOpcode::kSelect);
    495   // A phi value is not defined at a kSelect instruction because kSelect does
    496   // not create a new value. Rather it forwards a value from its operands. This
    497   // contrasts with kWhile instruction (which does define a phi value) which has
    498   // in-place update semantics.
    499   bool changed = false;
    500   for (auto& pair : GetInstructionValueSet(select)) {
    501     const ShapeIndex& index = pair.first;
    502     if (index.empty()) {
    503       // kSelect copies (not forwards) the top-level value.
    504       continue;
    505     }
    506     HloValueSet& value_set = pair.second;
    507     changed |=
    508         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
    509                                  &GetValueSet(select->operand(2), index)});
    510   }
    511   return changed;
    512 }
    513 
    514 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
    515   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    516   bool changed = false;
    517   for (int64 i = 0; i < tuple->operands().size(); ++i) {
    518     // Copy the value set(s) of each operand into the respective position in the
    519     // kTuple instruction's value sets.
    520     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
    521       const ShapeIndex& operand_index = pair.first;
    522       HloValueSet& operand_value_set = pair.second;
    523 
    524       ShapeIndex index = {i};
    525       for (int64 op_index : operand_index) {
    526         index.push_back(op_index);
    527       }
    528       HloValueSet& value_set = GetValueSet(tuple, index);
    529 
    530       if (value_set != operand_value_set) {
    531         value_set = operand_value_set;
    532         changed = true;
    533       }
    534     }
    535   }
    536   return changed;
    537 }
    538 
    539 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
    540   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
    541   std::vector<const InstructionValueSet*> inputs = {
    542       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
    543       &GetInstructionValueSet(xla_while->operand(0))};
    544   if (ssa_form_) {
    545     return Phi(xla_while, inputs);
    546   } else {
    547     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
    548   }
    549 }
    550 
    551 bool HloDataflowAnalysis::UpdateInstructionValueSet(
    552     HloInstruction* instruction) {
    553   // Recompute from operands.
    554   switch (instruction->opcode()) {
    555     case HloOpcode::kBitcast:
    556       return UpdateBitcastValueSet(instruction);
    557     case HloOpcode::kSlice:
    558       return UpdateSliceValueSet(instruction);
    559     case HloOpcode::kCopy:
    560       return UpdateCopyValueSet(instruction);
    561     case HloOpcode::kGetTupleElement:
    562       return UpdateGetTupleElementValueSet(instruction);
    563     case HloOpcode::kSelect:
    564       return UpdateSelectValueSet(instruction);
    565     case HloOpcode::kTuple:
    566       return UpdateTupleValueSet(instruction);
    567     case HloOpcode::kParameter:
    568       return UpdateParameterValueSet(instruction);
    569     case HloOpcode::kCall:
    570       return UpdateCallValueSet(instruction);
    571     case HloOpcode::kWhile:
    572       return UpdateWhileValueSet(instruction);
    573     case HloOpcode::kSend:
    574       return UpdateSendValueSet(instruction);
    575     case HloOpcode::kRecvDone:
    576       return UpdateRecvDoneValueSet(instruction);
    577     case HloOpcode::kConditional:
    578       return UpdateConditionalValueSet(instruction);
    579     default:
    580       // Instruction does not forward HloValues (it defines all values in its
    581       // output). No update is necessary.
    582       return false;
    583   }
    584 }
    585 
    586 void HloDataflowAnalysis::Propagate() {
    587   std::queue<HloInstruction*> worklist;
    588   tensorflow::gtl::FlatSet<HloInstruction*> workset;
    589   auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
    590     if (workset.insert(instruction).second) {
    591       worklist.push(instruction);
    592     }
    593   };
    594 
    595   for (HloComputation* computation : module_.computations()) {
    596     for (HloInstruction* instruction : computation->instructions()) {
    597       add_to_worklist(instruction);
    598     }
    599   }
    600 
    601   while (!worklist.empty()) {
    602     HloInstruction* instruction = worklist.front();
    603     worklist.pop();
    604     workset.erase(workset.find(instruction));
    605 
    606     VLOG(3) << "Worklist top: " << instruction->name();
    607     VLOG(3) << ToString();
    608 
    609     if (!UpdateInstructionValueSet(instruction)) {
    610       // No change to the instruction's value set.
    611       VLOG(4) << "No change.";
    612       continue;
    613     }
    614 
    615     VLOG(4) << "New value set for " << instruction->name() << ": "
    616             << GetInstructionValueSet(instruction);
    617 
    618     // Instruction value was updated. Add users to work list if we haven't
    619     // already.
    620     for (HloInstruction* user : instruction->users()) {
    621       add_to_worklist(user);
    622 
    623       // If user sequentially calls a computation, then the respective
    624       // parameter(s) of the computation need to be updated.
    625       if (user->opcode() == HloOpcode::kConditional) {
    626         // If operand 0 is the use of instruction, then no parameters need to be
    627         // updated, since that is the predicate of the conditional.
    628         // If operand 1 is the use of instruction, then the true_computation's
    629         // parameter need to be updated.
    630         // If operand 2 is the use of instruction, then the false_computation's
    631         // parameter need to be updated.
    632         //
    633         // Note that the same instruction can be used in both operand 1 and
    634         // operand 2.
    635         if (user->operand(1) == instruction) {
    636           add_to_worklist(user->true_computation()->parameter_instruction(0));
    637         }
    638         if (user->operand(2) == instruction) {
    639           add_to_worklist(user->false_computation()->parameter_instruction(0));
    640         }
    641       } else {
    642         for (HloComputation* called_computation : user->called_computations()) {
    643           const CallGraphNode& call_graph_node =
    644               call_graph_->GetNode(called_computation);
    645           if (call_graph_node.context() == CallContext::kSequential) {
    646             for (int64 operand_number : user->OperandIndices(instruction)) {
    647               add_to_worklist(
    648                   called_computation->parameter_instruction(operand_number));
    649             }
    650           }
    651         }
    652       }
    653     }
    654 
    655     // If instruction is a root instruction, then propagate out to any calling
    656     // instruction and across any while backedge.
    657     if (instruction == instruction->parent()->root_instruction()) {
    658       const CallGraphNode& call_graph_node =
    659           call_graph_->GetNode(instruction->parent());
    660       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    661         if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
    662             (callsite.instruction()->opcode() == HloOpcode::kConditional)) {
    663           add_to_worklist(callsite.instruction());
    664         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    665           // Add the while itself, and the body and condition parameters.
    666           add_to_worklist(callsite.instruction());
    667           add_to_worklist(
    668               callsite.instruction()->while_body()->parameter_instruction(0));
    669           add_to_worklist(
    670               callsite.instruction()->while_condition()->parameter_instruction(
    671                   0));
    672         }
    673       }
    674     }
    675   }
    676 }
    677 
    678 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    679     const HloInstruction* instruction) const {
    680   return value_sets_.at(instruction);
    681 }
    682 
    683 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    684     const HloInstruction* instruction) {
    685   return value_sets_.at(instruction);
    686 }
    687 
    688 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
    689   for (const HloComputation* computation : module_.computations()) {
    690     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
    691     for (HloInstruction* instruction : computation->instructions()) {
    692       // Create an empty shape tree.
    693       value_sets_.emplace(std::piecewise_construct,
    694                           std::forward_as_tuple(instruction),
    695                           std::forward_as_tuple(instruction->shape()));
    696 
    697       // Lambda to set the value set to define all values in the output of the
    698       // instruction.
    699       auto define_all_values = [this, &instruction](bool is_phi = false) {
    700         for (auto& pair : GetInstructionValueSet(instruction)) {
    701           const ShapeIndex& index = pair.first;
    702           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    703           GetValueSet(instruction, index).AddValue(value);
    704         }
    705       };
    706 
    707       // Lambda to set the value set to define only the top-level buffer in the
    708       // output of the instruction. Any other values flow from the operands of
    709       // the instruction (or from cross-computation dataflow).
    710       auto define_top_level_only = [this, &instruction]() {
    711         HloValue* value =
    712             NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false);
    713         GetValueSet(instruction, /*index=*/{}).AddValue(value);
    714       };
    715 
    716       // Lambda to set the value set at the given index of the output.
    717       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
    718         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    719         GetValueSet(instruction, index).AddValue(value);
    720       };
    721 
    722       switch (instruction->opcode()) {
    723         case HloOpcode::kBitcast:
    724           if (bitcast_defines_value_) {
    725             define_all_values();
    726           }
    727           break;
    728         case HloOpcode::kSlice:
    729           if (!instruction->IsInPlaceSlice()) {
    730             define_all_values();
    731           }
    732           break;
    733         case HloOpcode::kWhile:
    734         case HloOpcode::kCall:
    735         case HloOpcode::kConditional:
    736         case HloOpcode::kGetTupleElement:
    737           // These instructions define no values. The values in their output
    738           // flow from their operands or from cross computation dataflow.
    739           break;
    740         case HloOpcode::kParameter:
    741           if (call_graph_node.context() == CallContext::kBoth) {
    742             // We do not support a subcomputation that is called from both a
    743             // parallel and sequential context. In this case, the parameter
    744             // would both define a value and propagate a value from its
    745             // caller. This limitation is not really a problem because the call
    746             // graph is typically flattened.
    747             return Unimplemented(
    748                 "Computation %s is called in both a parallel (eg, kMap) and "
    749                 "sequential (eg, kCall) context",
    750                 computation->name().c_str());
    751           }
    752           if (call_graph_node.caller_callsites().empty() ||
    753               call_graph_node.context() == CallContext::kParallel) {
    754             // Parameters of computations called in a parallel context (eg, map
    755             // and reduce) as well as parameters of dead computations define all
    756             // values in their output. Otherwise the values of the parameter
    757             // come from the caller (eg, operands to the kCall instruction).
    758             define_all_values();
    759           }
    760           break;
    761         case HloOpcode::kCopy:
    762         case HloOpcode::kSelect:
    763         case HloOpcode::kTuple:
    764           // These instructions only define their top-level values. Any other
    765           // values flow from their operands.
    766           define_top_level_only();
    767           break;
    768         case HloOpcode::kRecvDone:
    769           // RecvDone aliases its input tuple element {0}, therefore does not
    770           // define any values.
    771           break;
    772         case HloOpcode::kSend:
    773           // Send produces a tuple of {aliased operand, U32 context}, therefore
    774           // only defines the top-level tuple and the tuple element at {1}.
    775           define_value_at(/*index=*/{});
    776           define_value_at(/*index=*/{1});
    777           break;
    778         default:
    779           define_all_values();
    780           break;
    781       }
    782     }
    783   }
    784 
    785   return Status::OK();
    786 }
    787 
    788 /* static */
    789 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    790     const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
    791   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
    792   XLA_VLOG_LINES(2, module.ToString());
    793 
    794   auto dataflow_analysis = WrapUnique(
    795       new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
    796 
    797   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
    798   dataflow_analysis->Propagate();
    799 
    800   // Delete all values marked for deletion.
    801   dataflow_analysis->DeleteMarkedValues();
    802 
    803   // Gather and set all non-definition positions of all values. Value deletion
    804   // is rare, so just use a vector indexed by Value::Id rather than a map from
    805   // Value::Id to positions. There should be very few holes in the vector, and
    806   // lookup is faster.
    807   std::vector<std::vector<HloPosition>> value_positions(
    808       dataflow_analysis->next_value_id_);
    809   for (const HloComputation* computation : module.computations()) {
    810     for (HloInstruction* instruction : computation->instructions()) {
    811       for (const auto& pair :
    812            dataflow_analysis->GetInstructionValueSet(instruction)) {
    813         const ShapeIndex& index = pair.first;
    814         const HloValueSet& value_set = pair.second;
    815         for (const HloValue* value : value_set.values()) {
    816           if (value->defining_instruction() != instruction) {
    817             value_positions[value->id()].push_back(
    818                 HloPosition{instruction, index});
    819           }
    820         }
    821       }
    822     }
    823   }
    824   for (auto& pair : dataflow_analysis->values_) {
    825     HloValue::Id value_id = pair.first;
    826     HloValue& value = pair.second;
    827     value.SetPositionsAndComputeUses(value_positions[value_id]);
    828   }
    829 
    830   // Construct vector of values.
    831   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
    832   for (auto& pair : dataflow_analysis->values_) {
    833     dataflow_analysis->values_vector_.push_back(&pair.second);
    834   }
    835   std::sort(dataflow_analysis->values_vector_.begin(),
    836             dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
    837 
    838   TF_DCHECK_OK(dataflow_analysis->Verify());
    839 
    840   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
    841 
    842   return std::move(dataflow_analysis);
    843 }
    844 
    845 Status HloDataflowAnalysis::Verify() const {
    846   // Verify each HloValue appears in the value sets that the value's positions()
    847   // indicate.
    848   for (const HloValue* value : values()) {
    849     for (const HloPosition& position : value->positions()) {
    850       const HloValueSet& value_set = GetValueSet(position);
    851       TF_RET_CHECK(std::find(value_set.values().begin(),
    852                              value_set.values().end(),
    853                              value) != value_set.values().end())
    854           << "Value set at position " << position << " does not contain value "
    855           << value->ToShortString();
    856     }
    857   }
    858 
    859   // For each value in each value set, verify that the value set's position
    860   // appears in the value's positions().
    861   for (const auto& computation : module_.computations()) {
    862     for (const auto& instruction : computation->instructions()) {
    863       for (const auto& pair : GetInstructionValueSet(instruction)) {
    864         const ShapeIndex& index = pair.first;
    865         const HloValueSet& value_set = pair.second;
    866         const HloPosition position{instruction, index};
    867         for (const HloValue* value : value_set.values()) {
    868           TF_RET_CHECK(std::find(value->positions().begin(),
    869                                  value->positions().end(),
    870                                  position) != value->positions().end())
    871               << "Value set at position " << position
    872               << " unexpectedly contains value " << value->ToShortString();
    873         }
    874       }
    875     }
    876   }
    877 
    878   return Status::OK();
    879 }
    880 
    881 }  // namespace xla
    882