Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
     17 
     18 #include <algorithm>
     19 #include <queue>
     20 #include <vector>
     21 
     22 #include "absl/container/flat_hash_set.h"
     23 #include "absl/container/inlined_vector.h"
     24 #include "absl/memory/memory.h"
     25 #include "absl/strings/str_cat.h"
     26 #include "tensorflow/compiler/xla/map_util.h"
     27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
     28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
     31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     32 #include "tensorflow/compiler/xla/shape_util.h"
     33 #include "tensorflow/compiler/xla/status.h"
     34 #include "tensorflow/compiler/xla/types.h"
     35 #include "tensorflow/compiler/xla/util.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 
     39 namespace xla {
     40 
     41 using absl::StrAppend;
     42 using absl::StrCat;
     43 
     44 HloDataflowAnalysis::HloDataflowAnalysis(
     45     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
     46     const FusionCanShareBufferFunction& fusion_can_share_buffer)
     47     : module_(module),
     48       ssa_form_(ssa_form),
     49       bitcast_defines_value_(bitcast_defines_value),
     50       call_graph_(CallGraph::Build(&module)),
     51       fusion_can_share_buffer_(fusion_can_share_buffer) {}
     52 
     53 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
     54     const HloInstruction* inst) {
     55   absl::flat_hash_set<const HloInstruction*> visited;
     56   absl::InlinedVector<const HloInstruction*, 4> stack;
     57   stack.push_back(inst);
     58   while (!stack.empty()) {
     59     const HloInstruction* current = stack.back();
     60     stack.pop_back();
     61     visited.insert(current);
     62     for (const HloInstruction* user : current->users()) {
     63       // Found a user that is non-elementwise on current instruction.
     64       for (const int64 use_index : user->OperandIndices(current)) {
     65         if (!user->IsElementwiseOnOperand(use_index) &&
     66             user->opcode() != HloOpcode::kTuple) {
     67           return false;
     68         }
     69       }
     70       if (!visited.contains(user)) {
     71         stack.push_back(user);
     72       }
     73     }
     74   }
     75   return true;
     76 }
     77 
     78 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
     79                                            const ShapeIndex& index) const {
     80   const HloValueSet& value_set = GetValueSet(instruction, index);
     81   if (value_set.values().size() != 1) {
     82     return false;
     83   }
     84   return value_set.GetUniqueValue().defining_instruction() == instruction;
     85 }
     86 
     87 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
     88     const HloInstruction* instruction, const ShapeIndex& index) const {
     89   CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
     90   return GetUniqueValueAt(instruction, index);
     91 }
     92 
     93 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
     94     const HloInstruction* instruction, const ShapeIndex& index) {
     95   CHECK(ValueIsDefinedAt(instruction, index));
     96   return GetUniqueValueAt(instruction, index);
     97 }
     98 
     99 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
    100                                            const ShapeIndex& index,
    101                                            bool is_phi) {
    102   const int64 value_id = next_value_id_++;
    103   auto emplaced = values_.emplace(
    104       std::piecewise_construct, std::forward_as_tuple(value_id),
    105       std::forward_as_tuple(value_id, instruction, index, is_phi));
    106   CHECK(emplaced.second);
    107 
    108   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
    109 
    110   return &emplaced.first->second;
    111 }
    112 
    113 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
    114   HloValue& value = values_.at(value_id);
    115   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
    116 
    117   value_ids_to_delete_.push_back(value_id);
    118 }
    119 
    120 void HloDataflowAnalysis::DeleteMarkedValues() {
    121 #ifndef NDEBUG
    122   // Verify that no marked-for-deletion values are in any of the value sets.
    123   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
    124                                            value_ids_to_delete_.end());
    125   for (const auto& pair : value_sets_) {
    126     const HloInstruction* instruction = pair.first;
    127     const InstructionValueSet& instruction_value_set = pair.second;
    128     for (const auto& index_value_set : instruction_value_set) {
    129       const HloValueSet& value_set = index_value_set.second;
    130       for (const HloValue* value : value_set.values()) {
    131         DCHECK(!ContainsKey(id_set, value->id()))
    132             << "Value " << value->ToShortString()
    133             << " marked for deletion, but still exists in value set for "
    134                "instruction "
    135             << instruction->name();
    136       }
    137     }
    138   }
    139 #endif
    140 
    141   for (HloValue::Id value_id : value_ids_to_delete_) {
    142     values_.erase(value_id);
    143   }
    144   value_ids_to_delete_.clear();
    145 }
    146 
    147 string HloDataflowAnalysis::ToString() const {
    148   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
    149   StrAppend(&out, "  Instruction value sets:\n");
    150   for (const HloComputation* computation : module_.computations()) {
    151     for (const HloInstruction* instruction : computation->instructions()) {
    152       StrAppend(&out, "    ", instruction->name(), ":\n");
    153       if (instruction->shape().IsTuple()) {
    154         GetInstructionValueSet(instruction)
    155             .ForEachElement([this, &instruction, &out](
    156                                 const ShapeIndex& index,
    157                                 const HloValueSet& value_set) {
    158               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
    159               for (const HloValue* value : value_set.values()) {
    160                 StrAppend(&out, "        ", value->ToShortString(),
    161                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
    162                           "\n");
    163               }
    164             });
    165       } else {
    166         const HloValueSet& top_level_value_set =
    167             GetValueSet(instruction, /*index=*/{});
    168         for (const HloValue* value : top_level_value_set.values()) {
    169           StrAppend(&out, "      ", value->ToShortString(),
    170                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
    171         }
    172       }
    173     }
    174   }
    175   StrAppend(&out, "  HloValues:\n");
    176   for (const HloValue* value : values()) {
    177     StrAppend(&out, value->ToString(/*indent=*/4));
    178   }
    179   return out;
    180 }
    181 
    182 bool HloDataflowAnalysis::Phi(
    183     HloInstruction* instruction,
    184     absl::Span<const InstructionValueSet* const> inputs) {
    185   CHECK(ssa_form_);
    186   VLOG(4) << "Phi(" << instruction->name() << ")";
    187   VLOG(5) << "instruction value set = "
    188           << GetInstructionValueSet(instruction).ToString();
    189   for (const InstructionValueSet* input : inputs) {
    190     VLOG(5) << "input value set = " << input->ToString();
    191   }
    192   for (const InstructionValueSet* input : inputs) {
    193     DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    194   }
    195 
    196   bool changed = false;
    197   for (auto& pair : GetInstructionValueSet(instruction)) {
    198     const ShapeIndex& index = pair.first;
    199     HloValueSet& value_set = pair.second;
    200 
    201     // Positions with phi values should never have more than one value in the
    202     // value set.
    203     CHECK_LE(value_set.values().size(), 1);
    204     const HloValue* current_value =
    205         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
    206 
    207     // Construct a vector of unique value IDs of the inputs.
    208     // Don't add value ids where the input is equal to the definition.
    209     std::vector<HloValue::Id> input_value_ids;
    210     for (const InstructionValueSet* input : inputs) {
    211       for (const HloValue* value : input->element(index).values()) {
    212         if (value->defining_instruction() == instruction &&
    213             value->defining_index() == index) {
    214           continue;
    215         }
    216         input_value_ids.push_back(value->id());
    217       }
    218     }
    219     absl::c_sort(input_value_ids);
    220     input_value_ids.erase(
    221         std::unique(input_value_ids.begin(), input_value_ids.end()),
    222         input_value_ids.end());
    223 
    224     // Remove the existing phi value (if it exists). The phi can be its own
    225     // input, for example, in while body parameters where the body passes
    226     // through the parameter value.
    227     bool current_value_defined_here =
    228         (current_value != nullptr &&
    229          current_value->defining_instruction() == instruction &&
    230          current_value->defining_index() == index);
    231     if (current_value_defined_here) {
    232       VLOG(5) << "current_value_defined_here: " << current_value->ToString();
    233       CHECK(current_value->is_phi());
    234       auto it = absl::c_find(input_value_ids, current_value->id());
    235       if (it != input_value_ids.end()) {
    236         input_value_ids.erase(it);
    237       }
    238     }
    239     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    240     if (input_value_ids.empty()) {
    241       // A value set which has at least one element should never have its value
    242       // set reduced to zero elements. During dataflow value sets only can go
    243       // from empty to non-empty, not the reverse.
    244       CHECK_EQ(value_set.values().size(), 0)
    245           << "Instruction " << instruction->name() << " at index " << index
    246           << " previously had non-empty value set. Value set: " << value_set;
    247     } else if (input_value_ids.size() == 1) {
    248       // Only a single value reaches this point. There should be no phi, and
    249       // this value set should contain this single value.
    250       const HloValue& new_value = GetValue(input_value_ids[0]);
    251       if (current_value == nullptr) {
    252         value_set.Clear();
    253         value_set.AddValue(&new_value);
    254         changed = true;
    255       } else if (current_value != &new_value) {
    256         if (current_value_defined_here) {
    257           // Remove the existing phi.
    258           MarkValueForDeletion(current_value->id());
    259         }
    260         value_set.Clear();
    261         value_set.AddValue(&new_value);
    262         changed = true;
    263       }
    264     } else {
    265       // Multiple distinct values reach this point. A phi value is
    266       // necessary.
    267       CHECK_GT(input_value_ids.size(), 1);
    268       if (current_value == nullptr ||
    269           !(current_value->is_phi() && current_value_defined_here)) {
    270         value_set.Clear();
    271         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
    272         changed = true;
    273       }
    274     }
    275   }
    276   return changed;
    277 }
    278 
    279 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
    280   return values_.at(value_id);
    281 }
    282 
    283 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
    284   return values_.at(value_id);
    285 }
    286 
    287 const HloValueSet& HloDataflowAnalysis::GetValueSet(
    288     const HloInstruction* instruction, const ShapeIndex& index) const {
    289   return GetInstructionValueSet(instruction).element(index);
    290 }
    291 
    292 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
    293                                               const ShapeIndex& index) {
    294   return *GetInstructionValueSet(instruction).mutable_element(index);
    295 }
    296 
    297 const HloValueSet& HloDataflowAnalysis::GetValueSet(
    298     const HloPosition& position) const {
    299   return GetValueSet(position.instruction, position.index);
    300 }
    301 
    302 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
    303   return GetValueSet(position.instruction, position.index);
    304 }
    305 
    306 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
    307   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
    308   const InstructionValueSet& operand_set =
    309       GetInstructionValueSet(bitcast->operand(0));
    310   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
    311   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
    312     bitcast_set = operand_set;
    313     return true;
    314   }
    315   return false;
    316 }
    317 
    318 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
    319   CHECK_EQ(send->opcode(), HloOpcode::kSend);
    320   bool changed = false;
    321   // Send forwards the operand value to the output tuple at {0}.
    322   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
    323     const ShapeIndex& operand_index = pair.first;
    324     const HloValueSet& operand_value_set = pair.second;
    325 
    326     ShapeIndex index = {0};
    327     for (int64 i : operand_index) {
    328       index.push_back(i);
    329     }
    330 
    331     HloValueSet& value_set = GetValueSet(send, index);
    332     if (value_set != operand_value_set) {
    333       value_set = operand_value_set;
    334       changed = true;
    335     }
    336   }
    337   return changed;
    338 }
    339 
    340 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
    341   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
    342   bool changed = false;
    343   // RecvDone forwards the operand value at {0} to element {0} of its output.
    344   for (auto& pair : GetInstructionValueSet(recv_done)) {
    345     ShapeIndex& index = pair.first;
    346     HloValueSet& value_set = pair.second;
    347 
    348     if (index.empty() || index[0] != 0) {
    349       continue;
    350     }
    351 
    352     const HloValueSet& operand_value_set =
    353         GetValueSet(recv_done->operand(0), index);
    354     if (value_set != operand_value_set) {
    355       value_set = operand_value_set;
    356       changed = true;
    357     }
    358   }
    359   return changed;
    360 }
    361 
    362 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
    363   CHECK_EQ(call->opcode(), HloOpcode::kCall);
    364   InstructionValueSet& value_set = GetInstructionValueSet(call);
    365   InstructionValueSet& root_value_set =
    366       GetInstructionValueSet(call->to_apply()->root_instruction());
    367   if (value_set != root_value_set) {
    368     value_set = root_value_set;
    369     return true;
    370   }
    371   return false;
    372 }
    373 
    374 bool HloDataflowAnalysis::UpdateConditionalValueSet(
    375     HloInstruction* conditional) {
    376   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
    377   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
    378   for (int j = 0; j < conditional->branch_count(); ++j) {
    379     inputs[j] = &GetInstructionValueSet(
    380         conditional->branch_computation(j)->root_instruction());
    381   }
    382   if (ssa_form_) {
    383     return Phi(conditional, inputs);
    384   } else {
    385     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
    386   }
    387 }
    388 
    389 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
    390   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    391   bool changed = false;
    392   for (auto& pair : GetInstructionValueSet(copy)) {
    393     const ShapeIndex& index = pair.first;
    394     if (index.empty()) {
    395       // kCopy shallow copies and thus defines the top-level value so nothing to
    396       // update.
    397       continue;
    398     }
    399 
    400     HloValueSet& value_set = pair.second;
    401     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
    402     if (value_set != operand_value_set) {
    403       value_set = operand_value_set;
    404       changed = true;
    405     }
    406   }
    407   return changed;
    408 }
    409 
    410 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
    411   // Domain instructions just forward their operand. Given that domains can have
    412   // a tuple operand, we iterate through its indexes, like for copies.
    413   // Unlike copies though we also propagate the top-level value.
    414   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
    415   bool changed = false;
    416   for (auto& pair : GetInstructionValueSet(domain)) {
    417     const ShapeIndex& index = pair.first;
    418     HloValueSet& value_set = pair.second;
    419     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
    420     if (value_set != operand_value_set) {
    421       value_set = operand_value_set;
    422       changed = true;
    423     }
    424   }
    425   return changed;
    426 }
    427 
    428 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
    429     HloInstruction* add_dependency) {
    430   // AddDependency just forwards the value of its zero-th operand.
    431   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
    432   const InstructionValueSet& operand_set =
    433       GetInstructionValueSet(add_dependency->operand(0));
    434   InstructionValueSet& add_dependency_set =
    435       GetInstructionValueSet(add_dependency);
    436   if (operand_set != add_dependency_set) {
    437     add_dependency_set = operand_set;
    438     return true;
    439   }
    440   return false;
    441 }
    442 
    443 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
    444   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
    445   bool changed = false;
    446   // The GetTupleElement instruction forwards the values from the specified
    447   // tuple element.
    448   for (auto& pair : GetInstructionValueSet(gte)) {
    449     const ShapeIndex& index = pair.first;
    450     HloValueSet& value_set = pair.second;
    451 
    452     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
    453     // with the tuple element number prefixed.
    454     ShapeIndex operand_index = {gte->tuple_index()};
    455     for (int64 i : index) {
    456       operand_index.push_back(i);
    457     }
    458 
    459     HloValueSet& operand_value_set =
    460         GetValueSet(gte->operand(0), operand_index);
    461     if (value_set != operand_value_set) {
    462       value_set = operand_value_set;
    463       changed = true;
    464     }
    465   }
    466   return changed;
    467 }
    468 
    469 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
    470   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
    471   const CallGraphNode& call_graph_node =
    472       call_graph_->GetNode(parameter->parent());
    473 
    474   // Subcomputations called in a parallel context (eg, map) do not have dataflow
    475   // from the caller operands.
    476   if (call_graph_node.context() == CallContext::kParallel ||
    477       call_graph_node.caller_callsites().empty()) {
    478     return false;
    479   }
    480   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
    481 
    482   std::vector<const InstructionValueSet*> inputs;
    483   bool need_phi = false;
    484   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    485     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
    486       // The operand values of a call instruction are forwarded to the
    487       // respective parameter instruction of the subcomputation.
    488       inputs.push_back(&GetInstructionValueSet(
    489           callsite.instruction()->operand(parameter->parameter_number())));
    490     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    491       // In a while instruction, the while operand (ie, the init value) and the
    492       // backedge are dataflow inputs to the parameter instruction. This is the
    493       // case for parameters of both the body and condition computations.
    494       CHECK_EQ(parameter->parameter_number(), 0);
    495       inputs.push_back(
    496           &GetInstructionValueSet(callsite.instruction()->operand(0)));
    497       // If the parameter *is* the root, then don't consider it's current state
    498       // (InstructionValueSet) as we are recomputing its current
    499       // state. Otherwise, the parameter state would never be updated.
    500       if (parameter !=
    501           callsite.instruction()->while_body()->root_instruction()) {
    502         inputs.push_back(&GetInstructionValueSet(
    503             callsite.instruction()->while_body()->root_instruction()));
    504       }
    505       need_phi = true;
    506     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
    507       CHECK_EQ(parameter->parameter_number(), 0);
    508       auto conditional = callsite.instruction();
    509       // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
    510       // operands 1 and onward are the arguments to the branch computations.
    511       //
    512       // If the parameter belongs to conditional's branch 0 computation, then
    513       // operand 1 is forwarded to this parameter instruction. If the parameter
    514       // belongs to conditional's branch 5 computation, then operand 6 is
    515       // forwarded to this parameter instruction.
    516       bool found_parent = false;
    517       for (int j = 0; j < conditional->branch_count(); ++j) {
    518         if (parameter->parent() == conditional->branch_computation(j)) {
    519           inputs.push_back(
    520               &GetInstructionValueSet(conditional->operand(j + 1)));
    521           found_parent = true;
    522           break;
    523         }
    524       }
    525       CHECK(found_parent);
    526       need_phi = true;
    527     } else {
    528       LOG(FATAL) << "CallContext::kSequential computations should only be "
    529                     "called from call, while, or conditional instructions";
    530     }
    531   }
    532 
    533   if (ssa_form_ && need_phi) {
    534     return Phi(parameter, inputs);
    535   } else {
    536     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
    537   }
    538 }
    539 
    540 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
    541   CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
    542   // A phi value is not defined at a kTupleSelect instruction because
    543   // kTupleSelect does not create a new value. Rather it forwards a value from
    544   // its operands. This contrasts with kWhile instruction (which does define a
    545   // phi value) which has in-place update semantics.
    546   bool changed = false;
    547   for (auto& pair : GetInstructionValueSet(select)) {
    548     const ShapeIndex& index = pair.first;
    549     if (index.empty()) {
    550       // kTupleSelect copies (not forwards) the top-level value.
    551       continue;
    552     }
    553     HloValueSet& value_set = pair.second;
    554     changed |=
    555         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
    556                                  &GetValueSet(select->operand(2), index)});
    557   }
    558   return changed;
    559 }
    560 
    561 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
    562   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    563   bool changed = false;
    564   for (int64 i = 0; i < tuple->operands().size(); ++i) {
    565     // Copy the value set(s) of each operand into the respective position in the
    566     // kTuple instruction's value sets.
    567     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
    568       const ShapeIndex& operand_index = pair.first;
    569       HloValueSet& operand_value_set = pair.second;
    570 
    571       ShapeIndex index = {i};
    572       for (int64 op_index : operand_index) {
    573         index.push_back(op_index);
    574       }
    575       HloValueSet& value_set = GetValueSet(tuple, index);
    576 
    577       if (value_set != operand_value_set) {
    578         value_set = operand_value_set;
    579         changed = true;
    580       }
    581     }
    582   }
    583   return changed;
    584 }
    585 
    586 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
    587   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
    588   const InstructionValueSet* const inputs[] = {
    589       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
    590       &GetInstructionValueSet(xla_while->operand(0))};
    591   if (ssa_form_) {
    592     return Phi(xla_while, inputs);
    593   } else {
    594     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
    595   }
    596 }
    597 
    598 bool HloDataflowAnalysis::UpdateInstructionValueSet(
    599     HloInstruction* instruction) {
    600   // Recompute from operands.
    601   switch (instruction->opcode()) {
    602     case HloOpcode::kAddDependency:
    603       return UpdateAddDependencyValueSet(instruction);
    604     case HloOpcode::kBitcast:
    605       return UpdateBitcastValueSet(instruction);
    606     case HloOpcode::kDomain:
    607       return UpdateDomainValueSet(instruction);
    608     case HloOpcode::kCopy:
    609       return UpdateCopyValueSet(instruction);
    610     case HloOpcode::kGetTupleElement:
    611       return UpdateGetTupleElementValueSet(instruction);
    612     case HloOpcode::kTupleSelect:
    613       return UpdateTupleSelectValueSet(instruction);
    614     case HloOpcode::kTuple:
    615       return UpdateTupleValueSet(instruction);
    616     case HloOpcode::kParameter:
    617       return UpdateParameterValueSet(instruction);
    618     case HloOpcode::kCall:
    619       return UpdateCallValueSet(instruction);
    620     case HloOpcode::kWhile:
    621       return UpdateWhileValueSet(instruction);
    622     case HloOpcode::kSend:
    623       return UpdateSendValueSet(instruction);
    624     case HloOpcode::kRecvDone:
    625       return UpdateRecvDoneValueSet(instruction);
    626     case HloOpcode::kConditional:
    627       return UpdateConditionalValueSet(instruction);
    628     default:
    629       // Instruction does not forward HloValues (it defines all values in its
    630       // output). No update is necessary.
    631       return false;
    632   }
    633 }
    634 
    635 void HloDataflowAnalysis::Propagate() {
    636   std::queue<HloInstruction*> worklist;
    637   absl::flat_hash_set<HloInstruction*> workset;
    638   auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
    639     if (workset.insert(instruction).second) {
    640       worklist.push(instruction);
    641     }
    642   };
    643 
    644   for (HloComputation* computation : module_.computations()) {
    645     for (HloInstruction* instruction : computation->instructions()) {
    646       add_to_worklist(instruction);
    647     }
    648   }
    649 
    650   while (!worklist.empty()) {
    651     HloInstruction* instruction = worklist.front();
    652     worklist.pop();
    653     workset.erase(workset.find(instruction));
    654 
    655     VLOG(3) << "Worklist top: " << instruction->name();
    656     VLOG(3) << ToString();
    657 
    658     if (!UpdateInstructionValueSet(instruction)) {
    659       // No change to the instruction's value set.
    660       VLOG(4) << "No change.";
    661       continue;
    662     }
    663 
    664     VLOG(4) << "New value set for " << instruction->name() << ": "
    665             << GetInstructionValueSet(instruction);
    666 
    667     // Instruction value was updated. Add users to work list if we haven't
    668     // already.
    669     for (HloInstruction* user : instruction->users()) {
    670       add_to_worklist(user);
    671 
    672       // If user sequentially calls a computation, then the respective
    673       // parameter(s) of the computation need to be updated.
    674       if (user->opcode() == HloOpcode::kConditional) {
    675         // If operand 0 is the use of instruction, then no parameters need to be
    676         // updated, since that is the branch_index of the conditional.
    677         // If operand n+1 is the use of instruction, then the branch_computation
    678         // n's parameter need to be updated.
    679         //
    680         // Note that the same instruction can be used in multiple branches'
    681         // operands.
    682         for (int j = 0; j < user->branch_count(); ++j) {
    683           if (user->operand(j + 1) == instruction) {
    684             add_to_worklist(
    685                 user->branch_computation(j)->parameter_instruction(0));
    686           }
    687         }
    688       } else {
    689         for (HloComputation* called_computation : user->called_computations()) {
    690           const CallGraphNode& call_graph_node =
    691               call_graph_->GetNode(called_computation);
    692           if (call_graph_node.context() == CallContext::kSequential) {
    693             for (int64 operand_number : user->OperandIndices(instruction)) {
    694               add_to_worklist(
    695                   called_computation->parameter_instruction(operand_number));
    696             }
    697           }
    698         }
    699       }
    700     }
    701 
    702     // If instruction is a root instruction, then propagate out to any calling
    703     // instruction and across any while backedge.
    704     if (instruction == instruction->parent()->root_instruction()) {
    705       const CallGraphNode& call_graph_node =
    706           call_graph_->GetNode(instruction->parent());
    707       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    708         if (callsite.instruction()->opcode() == HloOpcode::kCall ||
    709             callsite.instruction()->opcode() == HloOpcode::kConditional) {
    710           add_to_worklist(callsite.instruction());
    711         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
    712           // Add the while itself, and the body and condition parameters.
    713           add_to_worklist(callsite.instruction());
    714           add_to_worklist(
    715               callsite.instruction()->while_body()->parameter_instruction(0));
    716           add_to_worklist(
    717               callsite.instruction()->while_condition()->parameter_instruction(
    718                   0));
    719         }
    720       }
    721     }
    722   }
    723 }
    724 
    725 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    726     const HloInstruction* instruction) const {
    727   return value_sets_.at(instruction);
    728 }
    729 
    730 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    731     const HloInstruction* instruction) {
    732   return value_sets_.at(instruction);
    733 }
    734 
    735 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
    736   for (const HloComputation* computation : module_.computations()) {
    737     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
    738     for (HloInstruction* instruction : computation->instructions()) {
    739       // Create an empty shape tree.
    740       value_sets_.emplace(std::piecewise_construct,
    741                           std::forward_as_tuple(instruction),
    742                           std::forward_as_tuple(instruction->shape()));
    743 
    744       // Lambda to set the value set to define all values in the output of the
    745       // instruction.
    746       auto define_all_values = [this, &instruction](bool is_phi = false) {
    747         for (auto& pair : GetInstructionValueSet(instruction)) {
    748           const ShapeIndex& index = pair.first;
    749           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    750           GetValueSet(instruction, index).AddValue(value);
    751         }
    752       };
    753 
    754       // Lambda to set the value set to define only the top-level buffer in the
    755       // output of the instruction. Any other values flow from the operands of
    756       // the instruction (or from cross-computation dataflow).
    757       auto define_top_level_only = [this, &instruction]() {
    758         HloValue* value =
    759             NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false);
    760         GetValueSet(instruction, /*index=*/{}).AddValue(value);
    761       };
    762 
    763       // Lambda to set the value set at the given index of the output.
    764       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
    765         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
    766         GetValueSet(instruction, index).AddValue(value);
    767       };
    768 
    769       switch (instruction->opcode()) {
    770         case HloOpcode::kBitcast:
    771           if (bitcast_defines_value_) {
    772             define_all_values();
    773           }
    774           break;
    775         case HloOpcode::kAddDependency:
    776         case HloOpcode::kWhile:
    777         case HloOpcode::kCall:
    778         case HloOpcode::kConditional:
    779         case HloOpcode::kGetTupleElement:
    780         case HloOpcode::kDomain:
    781           // These instructions define no values. The values in their output
    782           // flow from their operands or from cross computation dataflow.
    783           break;
    784         case HloOpcode::kParameter:
    785           if (call_graph_node.context() == CallContext::kBoth) {
    786             // We do not support a subcomputation that is called from both a
    787             // parallel and sequential context. In this case, the parameter
    788             // would both define a value and propagate a value from its
    789             // caller. This limitation is not really a problem because the call
    790             // graph is typically flattened.
    791             return Unimplemented(
    792                 "Computation %s is called in both a parallel (eg, kMap) and "
    793                 "sequential (eg, kCall) context",
    794                 computation->name());
    795           }
    796           if (call_graph_node.caller_callsites().empty() ||
    797               call_graph_node.context() == CallContext::kParallel) {
    798             // Parameters of computations called in a parallel context (eg, map
    799             // and reduce) as well as parameters of dead computations define all
    800             // values in their output. Otherwise the values of the parameter
    801             // come from the caller (eg, operands to the kCall instruction).
    802             define_all_values();
    803           }
    804           break;
    805         case HloOpcode::kCopy:
    806         case HloOpcode::kTupleSelect:
    807         case HloOpcode::kTuple:
    808           // These instructions only define their top-level values. Any other
    809           // values flow from their operands.
    810           define_top_level_only();
    811           break;
    812         case HloOpcode::kRecvDone:
    813           // RecvDone produces a two-element tuple. Element zero aliases its
    814           // input tuple element {0}; element one is a token.
    815           define_value_at(/*index=*/{});
    816           define_value_at(/*index=*/{1});
    817           break;
    818         case HloOpcode::kSend:
    819           // Send produces a tuple of {aliased operand, U32 context, token},
    820           // therefore only defines the top-level tuple and the tuple elements
    821           // at {1} and {2}.
    822           define_value_at(/*index=*/{});
    823           define_value_at(/*index=*/{1});
    824           define_value_at(/*index=*/{2});
    825           break;
    826         default:
    827           define_all_values();
    828           break;
    829       }
    830     }
    831   }
    832 
    833   return Status::OK();
    834 }
    835 
    836 /* static */
    837 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    838     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
    839     const FusionCanShareBufferFunction& fusion_can_share_buffer) {
    840   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
    841   XLA_VLOG_LINES(2, module.ToString());
    842 
    843   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
    844       module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
    845 
    846   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
    847   dataflow_analysis->Propagate();
    848 
    849   // Delete all values marked for deletion.
    850   dataflow_analysis->DeleteMarkedValues();
    851 
    852   // Gather and set all non-definition positions of all values. Value deletion
    853   // is rare, so just use a vector indexed by Value::Id rather than a map from
    854   // Value::Id to positions. There should be very few holes in the vector, and
    855   // lookup is faster.
    856   std::vector<std::vector<HloPosition>> value_positions(
    857       dataflow_analysis->next_value_id_);
    858   for (const HloComputation* computation : module.computations()) {
    859     for (HloInstruction* instruction : computation->instructions()) {
    860       for (const auto& pair :
    861            dataflow_analysis->GetInstructionValueSet(instruction)) {
    862         const ShapeIndex& index = pair.first;
    863         const HloValueSet& value_set = pair.second;
    864         for (const HloValue* value : value_set.values()) {
    865           if (value->defining_instruction() != instruction) {
    866             value_positions[value->id()].push_back(
    867                 HloPosition{instruction, index});
    868           }
    869         }
    870       }
    871     }
    872   }
    873   for (auto& pair : dataflow_analysis->values_) {
    874     HloValue::Id value_id = pair.first;
    875     HloValue& value = pair.second;
    876     value.SetPositionsAndComputeUses(value_positions[value_id]);
    877   }
    878 
    879   // Construct vector of values.
    880   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
    881   for (auto& pair : dataflow_analysis->values_) {
    882     dataflow_analysis->values_vector_.push_back(&pair.second);
    883   }
    884   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
    885 
    886   TF_DCHECK_OK(dataflow_analysis->Verify());
    887 
    888   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
    889 
    890   return std::move(dataflow_analysis);
    891 }
    892 
    893 Status HloDataflowAnalysis::Verify() const {
    894   // Verify each HloValue appears in the value sets that the value's positions()
    895   // indicate.
    896   for (const HloValue* value : values()) {
    897     for (const HloPosition& position : value->positions()) {
    898       const HloValueSet& value_set = GetValueSet(position);
    899       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
    900           << "Value set at position " << position << " does not contain value "
    901           << value->ToShortString();
    902     }
    903   }
    904 
    905   // For each value in each value set, verify that the value set's position
    906   // appears in the value's positions().
    907   for (const auto& computation : module_.computations()) {
    908     for (const auto& instruction : computation->instructions()) {
    909       for (const auto& pair : GetInstructionValueSet(instruction)) {
    910         const ShapeIndex& index = pair.first;
    911         const HloValueSet& value_set = pair.second;
    912         const HloPosition position{instruction, index};
    913         for (const HloValue* value : value_set.values()) {
    914           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
    915               << "Value set at position " << position
    916               << " unexpectedly contains value " << value->ToShortString();
    917         }
    918       }
    919     }
    920   }
    921 
    922   return Status::OK();
    923 }
    924 
    925 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
    926     const HloInstruction* operand, const ShapeIndex& index,
    927     const HloInstruction* user) const {
    928   // Return false if no value at 'operand' and 'index' is used at 'user'.
    929   for (const HloValue* value : GetValueSet(operand, index).values()) {
    930     for (const HloUse& use : value->uses()) {
    931       if (use.instruction == user) {
    932         if (user->opcode() == HloOpcode::kFusion &&
    933             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
    934           HloInstruction* fusion_param =
    935               user->fused_parameter(use.operand_number);
    936           const HloValue& value =
    937               GetValueDefinedAt(fusion_param, use.operand_index);
    938           return value.uses().empty();
    939         }
    940         return false;
    941       }
    942     }
    943   }
    944   return true;
    945 }
    946 
    947 // Given a fusion whose root is a dynamic-update-slice op, determines whether
    948 // the fusion's output buffer can be shared with the buffer of fusion_param,
    949 // which must be a fused parameter of the fusion.
    950 //
    951 // Preconditions:
    952 //
    953 //  - fusion's root is a dynamic-update-slice op.
    954 //  - fusion_param is a parameter within the fusion.
    955 //
    956 // fusion_param may point to a subelement of the actual parameter instruction if
    957 // the param is a tuple; i.e. fusion_param->index() need not be the empty list.
    958 //
    959 // Returns true if:
    960 //
    961 //  * fusion is a loop or input fusion, AND
    962 //  * fusion_param is used by the root of dynamic-update-slice as the "base" of
    963 //    the update, i.e. the thing being updated, AND
    964 //  * all other uses of fusion_param are dynamic-slices that slice the same
    965 //    indices as are overwritten in the dynamic-update-slice.
    966 //
    967 // In the case that there are no other uses of fusion_param (last bullet point
    968 // is vacuously true) it's easy to see why an in-place DUS is safe; this is just
    969 // the "natural" implementation of DUS.  If there are other users, in-place DUS
    970 // is safe on the assumption that the thread which writes element i of the
    971 // output will be the only one to read element i of fusion_param (via the
    972 // dynamic-slice ops).
    973 static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
    974                                            const HloValue& fusion_param_value) {
    975   auto* root =
    976       Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
    977   auto* fusion_param = fusion_param_value.instruction();
    978   CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
    979   CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
    980 
    981   // fusion must be a loop or input fusion.
    982   auto kind = fusion->fusion_kind();
    983   if (kind != HloInstruction::FusionKind::kLoop &&
    984       kind != HloInstruction::FusionKind::kInput) {
    985     return false;
    986   }
    987 
    988   // fusion_param must be used by the root as the "base" of the
    989   // dynamic-update-slice.  The natural way to check this would be
    990   //
    991   //   `if (root->operand(0) != fusion_param)`
    992   //
    993   // but we also have to handle the case where the fusion parameter is
    994   // tuple-shaped and we're considering just one element of that tuple, i.e.
    995   // fusion_param.index() != {}.
    996   if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
    997         return use.instruction == root;
    998       }) != 1) {
    999     return false;
   1000   }
   1001 
   1002   // All other uses of fusion_param must be dynamic-slices that slice the same
   1003   // indices as are overwritten by the dynamic-update-slice.
   1004   for (const HloUse& use : fusion_param_value.uses()) {
   1005     auto* user = use.instruction;
   1006     if (user == root) {
   1007       continue;
   1008     }
   1009 
   1010     // Check that `user` is a dynamic-slice op and has the same slice indices as
   1011     // `root`.
   1012     auto* ds = DynCast<HloDynamicSliceInstruction>(user);
   1013     if (!ds || ds->index_operands() != root->index_operands()) {
   1014       return false;
   1015     }
   1016   }
   1017   return true;
   1018 }
   1019 
   1020 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
   1021     HloInstruction* operand, const ShapeIndex& operand_index,
   1022     HloInstruction* user, const ShapeIndex& user_index) const {
   1023   CHECK(user->IsUserOf(operand))
   1024       << "user: " << user->ToString() << " operand: " << operand->ToString();
   1025   const Shape& operand_subshape =
   1026       ShapeUtil::GetSubshape(operand->shape(), operand_index);
   1027   const Shape& user_subshape =
   1028       ShapeUtil::GetSubshape(user->shape(), user_index);
   1029 
   1030   // Check that operand and user emit the same shape and layout.
   1031   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
   1032     return false;
   1033   }
   1034 
   1035   if (user->opcode() == HloOpcode::kFusion) {
   1036     // Get the parameter associated with 'operand';
   1037     HloInstruction* fusion_param =
   1038         user->fused_parameter(user->operand_index(operand));
   1039 
   1040     const HloValue& fusion_param_value =
   1041         GetValueDefinedAt(fusion_param, operand_index);
   1042 
   1043     // TODO(b/80315712): This code is in a bit of a weird intermediate state
   1044     // at the moment. The in-place DUS check really needs to be common to all
   1045     // backends, so it runs first. Then we run the backend-specific check if
   1046     // provided, or go through the target-indepdendent check if not.
   1047     // Unfortunately, the notionally "target-independent" path actually contains
   1048     // some target-specific code, so we can't run all of it *in addition* to the
   1049     // target-specific function, like the interface documentation says.
   1050     if (user->fused_expression_root()->opcode() ==
   1051         HloOpcode::kDynamicUpdateSlice) {
   1052       return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
   1053     }
   1054 
   1055     if (fusion_can_share_buffer_ != nullptr) {
   1056       return fusion_can_share_buffer_(user, operand);
   1057     }
   1058 
   1059     if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
   1060         user->fusion_kind() == HloInstruction::FusionKind::kInput) {
   1061       return AreTransitiveUsesElementwiseOrTuple(fusion_param);
   1062     }
   1063 
   1064     if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
   1065         user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
   1066       // Output fusion with kAdd fused root.
   1067 
   1068       // Check if one operand of kAdd fused root is kDot or kConvolution.
   1069       auto* add = user->fused_expression_root();
   1070       auto add_operand_it =
   1071           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
   1072             return operand->opcode() == HloOpcode::kConvolution ||
   1073                    operand->opcode() == HloOpcode::kDot;
   1074           });
   1075       if (add_operand_it == add->operands().end()) {
   1076         return false;
   1077       }
   1078       auto* matched_add_operand = *add_operand_it;
   1079       // Calculate operand index of 'add' operand which was not matched above.
   1080       const int64 other_add_operand_index =
   1081           matched_add_operand == add->operand(0) ? 1 : 0;
   1082       // Returns true iff there is exactly one use of 'operand' at shape index
   1083       // 'operand_index', and this singleton use is the fused root (at operand
   1084       // index 'other_add_operand_index').
   1085       if (fusion_param_value.uses().size() == 1) {
   1086         const HloUse& use = fusion_param_value.uses()[0];
   1087         return use.instruction == user->fused_expression_root() &&
   1088                use.operand_number == other_add_operand_index;
   1089       }
   1090       return false;
   1091     }
   1092   }
   1093 
   1094   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
   1095       user->opcode() == HloOpcode::kScatter ||
   1096       user->opcode() == HloOpcode::kWhile) {
   1097     // We eliminated other users in BufferLiveness::live_range_strictly_before,
   1098     // so here we just need to check that the use is at operand index 0.
   1099     std::vector<int64> operand_indices = user->OperandIndices(operand);
   1100     return operand_indices.size() == 1 && operand_indices[0] == 0;
   1101   }
   1102   if (user->opcode() == HloOpcode::kSort) {
   1103     // Only valid if there are no other users.
   1104     if (operand->users().size() != 1) {
   1105       return false;
   1106     }
   1107     // If we only sort keys, the output of sort is not a tuple, so we can always
   1108     // share the buffer.
   1109     if (user->operand_count() == 1) {
   1110       return true;
   1111     }
   1112     CHECK(!user_index.empty());
   1113     // Only share with the right tuple element buffer.
   1114     std::vector<int64> operand_indices = user->OperandIndices(operand);
   1115     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
   1116   }
   1117   if (user->opcode() == HloOpcode::kCall) {
   1118     // Get all uses of value defined by 'operand' at 'operand_index'.
   1119     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
   1120     // Return true iff:
   1121     // *) There exists two uses of 'operand'.
   1122     // *) One use is by 'user' (caller).
   1123     // *) One use is by root instruction of called computation (callee root).
   1124     //    (Note: we check the root of the called computation, because the
   1125     //     root result buffer is required to alias with the Call result buffer).
   1126     // *) The root instruction of the called computation is element-wise on
   1127     //    'operand'.
   1128     const bool found_caller_use =
   1129         absl::c_find_if(uses, [user](const HloUse& use) {
   1130           return use.instruction == user;
   1131         }) != uses.end();
   1132     auto* callee_root = user->to_apply()->root_instruction();
   1133     const bool found_elementwise_callee_use =
   1134         absl::c_find_if(uses, [callee_root](const HloUse& use) {
   1135           return use.instruction == callee_root &&
   1136                  callee_root->IsElementwiseOnOperand(use.operand_number);
   1137         }) != uses.end();
   1138     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
   1139   }
   1140 
   1141   // Loop fusions that contain transposing copies won't reach here as they have
   1142   // different layouts, which fails the check in the beginning of this function.
   1143   return user->IsElementwiseOnOperand(user->operand_index(operand));
   1144 }
   1145 
   1146 }  // namespace xla
   1147