/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/instruction_fusion.h"

#include <algorithm>
#include <list>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {
// These nodes can always be duplicated into consumers, even if
// InstructionFusion::may_duplicate_ is false.
//
// In general these should be nodes that get *cheaper* the more they're
// duplicated (and fused into consumers).
//
// TODO(jlebar): Duplicating instructions when we have a variable called "may
// duplicate" that's equal to false is not pretty.
bool IsAlwaysDuplicable(const HloInstruction& instruction) {
  // We are always willing to duplicate a widening type-conversion instruction
  // if it means we can fuse the convert into a consumer.  This allows the
  // consumer to read less memory, which is almost always a performance win.
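  // For example, an f16->f32 convert reads two bytes per element but writes
  // four; fusing a copy of it into each consumer lets the consumer read the
  // narrower f16 data instead of a materialized f32 intermediate.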
  return instruction.opcode() == HloOpcode::kConvert &&
         ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
             ShapeUtil::ByteSizeOf(instruction.shape());
}
}  // namespace

/*static*/ bool InstructionFusion::IsExpensive(
    const HloInstruction& instruction) {
  switch (instruction.opcode()) {
    // Cheap instructions.
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConstant:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kFloor:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kImag:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kOutfeed:
    case HloOpcode::kPad:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSlice:
    case HloOpcode::kSubtract:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return false;

    // Cheap instructions for reals, but expensive for complex.
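    // For example, abs(a+bi) = sqrt(a^2 + b^2) requires a square root, and
    // complex sin/cos expand into exponential/hyperbolic terms.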
    case HloOpcode::kAbs:
    case HloOpcode::kCos:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
      return ShapeUtil::ElementIsComplex(instruction.shape());

    // Expensive instructions or unusual instructions for which fusion is
    // nonsensical.
    case HloOpcode::kAddDependency:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAtan2:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kCall:
    case HloOpcode::kCholesky:
    case HloOpcode::kConditional:
    case HloOpcode::kConvolution:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDivide:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kParameter:
    case HloOpcode::kPower:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
    case HloOpcode::kTrace:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kWhile:
    case HloOpcode::kGetDimensionSize:
      return true;
  }

  return false;
}

// An "effectively at most unary" operation is one that has at most one
// "large" input, with the others negligible in terms of memory usage. As a
// heuristic for "negligible", we check whether the operand has a smaller true
// rank than the output.
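//
// For example, pad(big_tensor, scalar_pad_value) is effectively unary: the
// scalar's true rank (0) is below the output's rank, so only the tensor
// operand counts as a "large" input.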
bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
  int64 output_rank = 0;
  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsArray()) {
          output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
        }
      });
  return absl::c_count_if(
             hlo->operands(), [output_rank](HloInstruction* operand) {
               if (operand->opcode() == HloOpcode::kBroadcast ||
                   operand->opcode() == HloOpcode::kIota) {
                 return false;
               }
               if (operand->opcode() == HloOpcode::kConstant &&
                   ShapeUtil::IsEffectiveScalar(operand->shape())) {
                 return false;
               }
               return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
             }) <= 1;
}

bool InstructionFusion::CanFuseOnAllPaths(
    HloInstruction* producer, HloInstruction* consumer,
    const HloInstructionSet& do_not_fuse,
    absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
        result_cache) {
  if (consumer == producer) {
    return true;
  }
  if (!consumer->IsFusible()) {
    return false;
  }
  auto cache_it = result_cache->find(std::make_pair(producer, consumer));
  if (cache_it != result_cache->end()) {
    return cache_it->second;
  }
  bool result = true;
  for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
    auto* consumer_operand = consumer->mutable_operand(i);
    // If the operand is not on a path to the producer, it doesn't matter
    // whether it's fusible.
    if (!reachability_->IsReachable(producer, consumer_operand)) {
      continue;
    }
    if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
      result = false;
      break;
    }
    // The producer is reachable from consumer_operand, which means we need to
    // be able to fuse consumer_operand into consumer in order for producer to
    // be fusible into consumer on all paths.
    // Perform the recursive step: make sure producer can be fused into
    // consumer_operand on all paths.
    if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
                           result_cache)) {
      result = false;
      break;
    }
  }
  result_cache->emplace(std::make_pair(producer, consumer), result);
  return result;
}

InstructionFusion::HloInstructionSet
InstructionFusion::ComputeGloballyUnfusible(
    absl::Span<HloInstruction* const> post_order) {
  // Forbid fusion of producers that:
  // a) Need to be duplicated, unless they can be fused into all consumers
  //    via all paths.
  // b) Are more than unary, that is, fusing them would likely lead to an
  //    increase in memory bandwidth use.
  //
  // Note that if we allow fusion by these global rules, we may still forbid
  // fusing operations that require duplication later depending on
  // is_expensive_().
  HloInstructionSet do_not_duplicate;
  absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
      can_fuse_on_all_paths_result_cache;
  for (HloInstruction* consumer : post_order) {
    for (HloInstruction* producer : consumer->operands()) {
      if (do_not_duplicate.count(producer) > 0) {
        continue;
      }

      // If the producer is effectively at most unary, duplicating it will not
      // increase the number of relevant inputs read, as the fusion node will
      // only need to read at most one relevant input (the input of the
      // producer). In that case, we do not forbid fusion of the operation
      // here.
      if (EffectivelyAtMostUnary(producer)) {
        continue;
      }

      // If the total size of the inputs is less than or equal to the total
      // size of the outputs for the producer, then duplicating it won't
      // increase the memory traffic. In that case, we do not forbid fusion of
      // the operation here.
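      // For example, concatenating two [16,16] tensors into [32,16] reads 512
      // elements and writes 512, so duplicating it adds no net traffic,
      // whereas an add of two [16,16] tensors reads 512 but writes only 256.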
      auto total_size = [](const Shape& shape) {
        int64 size = 0;
        ShapeUtil::ForEachSubshape(
            shape,
            [&size](const Shape& subshape, const ShapeIndex& shape_index) {
              if (subshape.IsArray()) {
                size += ShapeUtil::ElementsIn(subshape);
              }
            });
        return size;
      };
      int64 operands_size = 0;
      for (const HloInstruction* op : producer->operands()) {
        operands_size += total_size(op->shape());
      }
      if (operands_size <= total_size(producer->shape())) {
        continue;
      }

      // Otherwise we will forbid fusing the op unless we can fuse it into
      // all of its consumers on all paths.
      //
      // That means that for:
      // A --> B (fusible)
      //   \-> C (non-fusible)
      // A will not be allowed to be fused into B, as it cannot be fused into
      // C.
      //
      // Similarly, for:
      // A -------------> B
      //   \-> C -> D -/
      // If:
      // - A is fusible into B and C, and D is fusible into B
      // - C is *not* fusible into D
      // A will not be allowed to be fused into B, as it cannot be fused via
      // all paths.
      if (producer->IsFusible() &&
          CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
                            &can_fuse_on_all_paths_result_cache)) {
        continue;
      }
      do_not_duplicate.insert(producer);
    }
  }

  return do_not_duplicate;
}

namespace {

// A FusionQueue that uses reverse post order.
//
// We want to be able to remove arbitrary instructions from the post order and
// also compare positions of instructions in the post order. To make this
// possible, we create a vector of instructions in post order and a map from
// HloInstruction* to the instruction's index in the vector. An instruction is
// "removed" from the vector by setting its element to nullptr.
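//
// For example, given a post order [A, B, C] (producers before consumers),
// instructions are dequeued C, then B, then A, so each consumer is visited
// before the producers it might absorb.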
class ReversePostOrderFusionQueue : public FusionQueue {
 public:
  explicit ReversePostOrderFusionQueue(HloComputation* computation) {
    post_order_ = computation->MakeInstructionPostOrder();

    for (size_t i = 0; i < post_order_.size(); ++i) {
      InsertOrDie(&post_order_index_, post_order_[i], i);
    }
  }

    336 
    337   std::pair<HloInstruction*, std::vector<int64>>
    338   DequeueNextInstructionAndOperandsToFuseInOrder() override {
    339     // Instructions are "removed" from the post order by nulling out the element
    340     // in the vector, so if the pointer is null, continue to the next
    341     // instruction in the sort.
    342     while (!post_order_.empty() && post_order_.back() == nullptr) {
    343       post_order_.pop_back();
    344     }
    345     if (post_order_.empty()) {
    346       return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
    347     }
    348     // We want to iterate in reverse post order, so remove from the back of the
    349     // vector.
    350     HloInstruction* instruction = post_order_.back();
    351     post_order_.pop_back();
    352 
    353     CHECK(instruction != nullptr);
    354     // Remove instruction from the index map to ensure the vector and map stay
    355     // consistent.
    356     post_order_index_.erase(instruction);
    357 
    358     // Consider each operand of this instruction for fusion into this
    359     // instruction. We want to consider the operands in a particular order to
    360     // avoid creating duplicate instruction clones in the fusion instruction.
    361     // For example, consider the following expression:
    362     //
    363     //   A = ...
    364     //   B = op(A)
    365     //   C = op(A, B)
    366     //
    367     // If we are considering the operands of C for fusion into C. We might
    368     // fuse A or B first. If we fuse A first, we get:
    369     //
    370     //   A = ...
    371     //   B = op(A)
    372     //   C_fusion = { A' = ...
    373     //                C' = op(A', B) }
    374     //
    375     // Where A' and C' are clones of A and C, respectively. Now only B is an
    376     // operand of the fusion instruction C_fusion, so then we fuse B:
    377     //
    378     //   A = ...
    379     //   B = op(A)
    380     //   C_fusion = { A' = ...
    381     //                B' = op(A)
    382     //                C' = op(A', B') }
    383     //
    384     // Now A is an operand of C_fusion again, so we then fuse A (again!):
    385     //
    386     //   A = ...
    387     //   B = op(A)
    388     //   C_fusion = { A' = ...
    389     //                A" = ..
    390     //                B' = op(A")
    391     //                C' = op(A', B') }
    392     //
    393     // We prevent this duplication by considering the operands in the order
    394     // they appear int the queue. In the example, this ensures that B will be
    395     // considered before A.
    396     //
    397     // We store the original indices of the operands to pass to ShouldFuse.
    398     std::vector<int64> sorted_operand_numbers;
    399     sorted_operand_numbers.reserve(instruction->operands().size());
    400     for (int i = 0; i < instruction->operands().size(); ++i) {
    401       // This will happen if we have two possible instructions to fuse the
    402       // same operand into; once the operand is fused into one instruction,
    403       // the other instruction will get a new get-tuple-element as its
    404       // operand, which is not in the queue.
    405       // TODO(tjoerg): Look into fusing past these multi-output fuse points.
    406       if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
    407         continue;
    408       }
    409       sorted_operand_numbers.push_back(i);
    410     }
    411     absl::c_sort(
    412         sorted_operand_numbers, [&](int64 i, int64 j) {
    413           // Instructions with higher priority in the queue come first.
    414           return (
    415               FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
    416               FindOrDie(post_order_index_, instruction->mutable_operand(j)));
    417         });
    418     return std::make_pair(instruction, sorted_operand_numbers);
    419   }
    420 
  void OnFusingInstruction(HloInstruction* fusion,
                           HloInstruction* original_producer,
                           HloInstruction* original_consumer) override {
    // Fusing an instruction into a fusion instruction can change the operand
    // set of the fusion instruction. For simplicity just re-enqueue the
    // instruction and reconsider it for further fusion in the next iteration.
    InsertOrDie(&post_order_index_, fusion, post_order_.size());
    post_order_.push_back(fusion);
  }

  void RemoveInstruction(HloInstruction* instruction) override {
    post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
    post_order_index_.erase(instruction);
  }

 private:
  std::vector<HloInstruction*> post_order_;
  absl::flat_hash_map<HloInstruction*, int> post_order_index_;
};

}  // namespace

std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
    HloComputation* computation) {
  return absl::make_unique<ReversePostOrderFusionQueue>(computation);
}

StatusOr<bool> InstructionFusion::Run(HloModule* module) {
  VLOG(2) << "Before instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  bool changed = false;
  module_ = module;
  for (auto* computation : module->MakeNonfusionComputations()) {
    CHECK(!computation->IsFusionComputation());
    computation_ = computation;
    reachability_ = HloReachabilityMap::Build(computation_);

    HloInstructionSet do_not_duplicate;
    // If we allow duplications, we need to compute which instructions we do
    // not want to duplicate based on a global analysis of the graph.
    if (may_duplicate_) {
      do_not_duplicate =
          ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
    }
    auto fusion_queue = GetFusionQueue(computation_);

    // Instruction fusion effectively fuses edges in the computation graph
    // (producer instruction -> consumer instruction) so we iterate over all
    // edges. When we fuse an edge, we create a copy of the producer inside
    // the fusion instruction.
    while (true) {
      auto next_entry =
          fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
      auto instruction = next_entry.first;
      if (instruction == nullptr) {
        break;
      }

      if (!instruction->IsFusible() &&
          instruction->opcode() != HloOpcode::kFusion) {
        continue;
      }

      std::vector<int64>& sorted_operand_numbers = next_entry.second;

      for (int64 i : sorted_operand_numbers) {
        HloInstruction* operand = instruction->mutable_operand(i);

        if (!operand->IsFusible()) {
          continue;
        }

        HloInstruction* fusion_instruction;
        // Try "regular" fusion if the operand may be duplicated. Otherwise,
        // perform multi-output fusion, unless this creates a cycle.
        if (do_not_duplicate.count(operand) == 0 &&
            ShouldFuse(instruction, i)) {
          fusion_queue->PreFusion(operand, instruction);
          fusion_instruction = Fuse(operand, instruction);
        } else if (ShouldFuseIntoMultiOutput(instruction, i) &&
                   !MultiOutputFusionCreatesCycle(operand, instruction)) {
          fusion_queue->PreFusion(operand, instruction);
          fusion_instruction = FuseIntoMultiOutput(operand, instruction);
        } else {
          continue;
        }

        fusion_queue->OnFusingInstruction(fusion_instruction, operand,
                                          instruction);
        changed = true;

        if (operand->user_count() == 0) {
          do_not_duplicate.erase(operand);
          // Operand is now dead. Remove from queue.
          fusion_queue->RemoveInstruction(operand);
          // Remove from computation.
          TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
        }

        if (fusion_instruction != instruction) {
          do_not_duplicate.erase(instruction);
        }
        break;
      }
    }
  }

  VLOG(2) << "After instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  return changed;
}

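// Returns the fusion instruction into which `producer` will be fused: the
// consumer itself if it is already a fusion instruction, otherwise a new
// fusion instruction created to replace `consumer`.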
HloInstruction* InstructionFusion::AddFusionInstruction(
    HloInstruction* producer, HloInstruction* consumer) {
  HloInstruction* fusion_instruction;
  auto kind = ChooseKind(producer, consumer);
  if (consumer->opcode() == HloOpcode::kFusion) {
    fusion_instruction = consumer;
    if (kind != fusion_instruction->fusion_kind()) {
      fusion_instruction->set_fusion_kind(kind);
    }
  } else {
    fusion_instruction = computation_->AddInstruction(
        HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
    TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
  }
  return fusion_instruction;
}

HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
                                        HloInstruction* consumer) {
  VLOG(2) << "Fusing " << producer->ToString() << " into "
          << consumer->ToString();
  HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
  fusion_instruction->FuseInstruction(producer);
  return fusion_instruction;
}

HloInstruction* InstructionFusion::FuseIntoMultiOutput(
    HloInstruction* producer, HloInstruction* consumer) {
  VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
          << consumer->ToString();
  HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
  fusion_instruction->FuseInstructionIntoMultiOutput(producer);
  return fusion_instruction;
}

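// Multi-output fusing producer P into consumer C merges them into a single
// fusion node F. If some other operand O of C is reachable from P, then F
// would both feed O (through P's result) and consume O (as C's operand),
// creating the cycle F -> ... -> O -> F.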
bool InstructionFusion::MultiOutputFusionCreatesCycle(
    HloInstruction* producer, HloInstruction* consumer) {
  absl::flat_hash_set<int> operands;
  for (const HloInstruction* operand : consumer->operands()) {
    if (operand == producer) {
      continue;
    }

    // If the reachability map already contains the producer and the operand
    // of the consumer, and the producer can reach the operand, then we know
    // for sure MultiOutputFusion would create a cycle. If not, we need to do
    // a DFS traversal of the computation to verify that this multi-output
    // fusion would not create a cycle.
    if (reachability_->IsPresent(producer) &&
        reachability_->IsPresent(operand) &&
        reachability_->IsReachable(producer, operand)) {
      return true;
    }
    operands.insert(operand->unique_id());
  }

  // Do a DFS on the producer to see if any of the other consumer operands are
  // reachable in the current state of the graph.
  std::vector<HloInstruction*> worklist = producer->users();
  absl::flat_hash_set<int> visits;
  while (!worklist.empty()) {
    const HloInstruction* user = worklist.back();
    worklist.pop_back();
    if (operands.count(user->unique_id()) != 0) {
      return true;
    }
    if (visits.count(user->unique_id()) == 0) {
      visits.insert(user->unique_id());
      worklist.insert(worklist.end(), user->users().begin(),
                      user->users().end());
    }
  }
  return false;
}

bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
                                   int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);

  // Cost condition: don't duplicate expensive instructions.
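  // For example, an expensive producer (say, exp) with several users would be
  // recomputed inside every fusion it is duplicated into, so we decline
  // unless duplication is permitted and the producer is cheap (or always
  // duplicable, like a widening convert).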
  if (FusionWouldDuplicate(*producer, *consumer) &&
      (!may_duplicate_ || is_expensive_(*producer)) &&
      !IsAlwaysDuplicable(*producer)) {
    return false;
  }

  if (consumer->opcode() == HloOpcode::kFusion &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
    return false;
  }

  if (producer->CouldBeBitcast() &&
      // We can't fuse parameters anyhow, so we leave the user unfused to
      // become a bitcast. If the operand is not a parameter, we would break a
      // potential fusion to make it a bitcast, which is not so clear a win.
      producer->operand(0)->opcode() == HloOpcode::kParameter) {
    return false;
  }

  return true;
}

HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return HloInstruction::FusionKind::kLoop;
}

}  // namespace xla