Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
     17 
     18 #include <algorithm>
     19 #include <list>
     20 #include <memory>
     21 #include <numeric>
     22 #include <vector>
     23 
     24 #include "tensorflow/compiler/xla/map_util.h"
     25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/gtl/flatmap.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 
     30 namespace xla {
// Default cost model: classifies each HLO opcode as cheap or expensive to
// recompute. Fusion decisions use this to avoid duplicating instructions
// whose recomputation would be costly.
/*static*/ bool InstructionFusion::IsExpensive(
    const HloInstruction& instruction) {
  switch (instruction.opcode()) {
    // Cheap instructions.
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConstant:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kEq:
    case HloOpcode::kFloor:
    case HloOpcode::kGe:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kGt:
    case HloOpcode::kImag:
    case HloOpcode::kInfeed:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kOutfeed:
    case HloOpcode::kPad:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSlice:
    case HloOpcode::kSubtract:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
      return false;

    // Cheap instructions for reals, but expensive for complex.
    case HloOpcode::kAbs:
    case HloOpcode::kCos:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
      return ShapeUtil::ElementIsComplex(instruction.shape());

    // Expensive instructions.
    case HloOpcode::kAtan2:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvolution:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDivide:
    case HloOpcode::kDot:
    case HloOpcode::kExp:
    case HloOpcode::kFft:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kHostCompute:
    case HloOpcode::kLog:
    case HloOpcode::kMap:
    case HloOpcode::kParameter:
    case HloOpcode::kPower:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSort:
    case HloOpcode::kTanh:
    case HloOpcode::kTrace:
    case HloOpcode::kWhile:
      return true;
  }

  // Unreachable for a valid opcode; kept so control never falls off the end
  // of a non-void function if the enum gains values not handled above.
  return false;
}
    129 
    130 // An "effectively unary" operation is one that has one "large"
    131 // input with the others being negligible in terms of memory usage.
    132 // We use "has a smaller true rank than the output" as a heuristic
    133 // for "negligible" memory usage.
    134 bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
    135   int64 output_rank = 0;
    136   ShapeUtil::ForEachSubshape(
    137       hlo->shape(),
    138       [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
    139         if (ShapeUtil::IsArray(subshape)) {
    140           output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
    141         }
    142       });
    143   return std::count_if(hlo->operands().begin(), hlo->operands().end(),
    144                        [output_rank](HloInstruction* operand) {
    145                          if (operand->opcode() == HloOpcode::kBroadcast) {
    146                            return false;
    147                          }
    148                          if (operand->opcode() == HloOpcode::kConstant &&
    149                              ShapeUtil::IsEffectiveScalar(operand->shape())) {
    150                            return false;
    151                          }
    152                          return ShapeUtil::TrueRank(operand->shape()) >=
    153                                 output_rank;
    154                        }) <= 1;
    155 }
    156 
// Returns true if `producer` can be fused into `consumer` along *every*
// graph path from `consumer` up towards `producer`. Walks the consumer's
// operands recursively, guided by `reachability_map`, and memoizes negative
// results: any producer that fails is inserted into `do_not_fuse`, so
// subsequent queries for it short-circuit to false.
bool InstructionFusion::CanFuseOnAllPaths(
    const HloReachabilityMap& reachability_map, HloInstruction* producer,
    HloInstruction* consumer, DoNotFuseSet* do_not_fuse) {
  auto could_fuse_on_all_paths = [&] {
    // First check to see if we have already marked this producer as infeasible
    // to fuse into consumer.
    if (do_not_fuse->count(producer) > 0) {
      return false;
    }
    // Make sure it is possible for producer and consumer to exist in a fusion
    // node.
    if (!producer->IsFusable() || !consumer->IsFusable()) {
      return false;
    }
    // We do an upward walk of the graph from consumer towards all paths which
    // lead to producer to find any unfusable paths.
    for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
      auto* consumer_operand = consumer->mutable_operand(i);
      if (consumer_operand == producer) {
        // This is the base case: our upward crawl ends but we need to make sure
        // that fusion from consumer can happen.
        if (!ShouldFuse(consumer, i)) {
          return false;
        }
      } else if (reachability_map.IsReachable(producer, consumer_operand)) {
        // The reachability map told us that consumer_operand is a node on the
        // path to producer. We need to further investigate from
        // consumer_operand.

        // First check if we have already ruled out fusing producer into
        // consumer_operand.
        if (do_not_fuse->count(consumer_operand) > 0) {
          return false;
        }
        // Make sure it is possible for consumer_operand to exist in a fusion
        // node.
        if (!consumer_operand->IsFusable()) {
          return false;
        }
        // The producer is reachable from consumer_operand which means we need
        // to be able to fuse consumer_operand into consumer in order for
        // producer to be fusable into consumer on all paths.
        if (!ShouldFuse(consumer, i)) {
          return false;
        }
        // Perform the recursive step: make sure producer can be fused into
        // consumer_operand on all paths.
        if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand,
                               do_not_fuse)) {
          return false;
        }
      }
    }
    return true;
  };
  if (could_fuse_on_all_paths()) {
    return true;
  }
  // We couldn't fuse on all paths, record this result.
  do_not_fuse->insert(producer);
  return false;
}
    219 
// Runs instruction fusion over every non-fusion computation in `module`.
// Returns true iff at least one fusion was performed (i.e. the module
// changed).
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
  VLOG(2) << "Before instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  bool changed = false;
  module_ = module;
  for (auto* computation : module->MakeNonfusionComputations()) {
    CHECK(!computation->IsFusionComputation());
    computation_ = computation;

    // We want to be able to remove arbitrary instructions from the post order
    // and also compare positions of instructions in the post order. To make
    // this possible, create vector of instructions in post order and create a
    // map from HloInstruction* to the instruction's index in the vector. An
    // instruction is "removed" from the vector by setting it's element to
    // nullptr.
    std::list<HloInstruction*> post_order_list =
        computation_->MakeInstructionPostOrder();
    std::vector<HloInstruction*> post_order(post_order_list.begin(),
                                            post_order_list.end());

    tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
    for (size_t i = 0; i < post_order.size(); ++i) {
      InsertOrDie(&post_order_index, post_order[i], i);
    }

    DoNotFuseSet do_not_fuse;
    auto reachability = computation->ComputeReachability();

    // Producers in these categories are cheap enough to recompute that
    // duplicating them into a fusion node is always acceptable, so they are
    // exempt from the all-paths analysis below.
    auto cheap_to_duplicate = [this](HloInstruction* producer) {
      if (producer->opcode() == HloOpcode::kBroadcast) {
        return true;
      }
      if (producer->opcode() == HloOpcode::kConstant &&
          ShapeUtil::IsEffectiveScalar(producer->shape())) {
        return true;
      }
      if (EffectivelyUnary(producer)) {
        return true;
      }
      return false;
    };

    // Precompute do_not_fuse: for every producer->consumer edge whose producer
    // is not cheap to duplicate, run the all-paths check; producers that fail
    // are recorded inside CanFuseOnAllPaths and skipped in the fusion loop.
    for (HloInstruction* consumer : post_order) {
      for (HloInstruction* producer : consumer->operands()) {
        if (cheap_to_duplicate(producer)) {
          continue;
        }
        if (CanFuseOnAllPaths(*reachability, producer, consumer,
                              &do_not_fuse)) {
          CHECK_EQ(do_not_fuse.count(producer), 0);
        } else {
          CHECK_GT(do_not_fuse.count(producer), 0);
        }
      }
    }

    // Instruction fusion effectively fuses edges in the computation graph
    // (producer instruction -> consumer instruction) so we iterate over all
    // edges. When we fuse an edge, we create a copy of the producer inside the
    // fusion instruction.
    while (!post_order.empty()) {
      // We want to iterate in reverse post order, so remove from the back of
      // the vector.
      HloInstruction* instruction = post_order.back();
      post_order.pop_back();

      // Instructions are "removed" from the post order by nulling out the
      // element in the vector, so if the pointer is null, continue to the next
      // instruction in the sort.
      if (instruction == nullptr) {
        continue;
      }

      // Remove instruction from the index map to ensure the vector and map stay
      // consistent.
      post_order_index.erase(instruction);

      // Skip instructions that cannot participate in fusion. Existing fusion
      // instructions are still considered so more producers can be fused into
      // them.
      if (!instruction->IsFusable() &&
          instruction->opcode() != HloOpcode::kFusion) {
        continue;
      }

      // Consider each operand of this instruction for fusion into this
      // instruction. We want to consider the operands in a particular order to
      // avoid created duplicate instruction clones in the fusion instruction.
      // For example, consider the following expression:
      //
      //   A = ...
      //   B = op(A)
      //   C = op(A, B)
      //
      // If we are considering the operands of C for fusion into C. We might
      // fuse A or B first. If we fuse A first, we get:
      //
      //   A = ...
      //   B = op(A)
      //   C_fusion = { A' = ...
      //                C' = op(A', B) }
      //
      // Where A' and C' are clones of A and C, respectively. Now only B is an
      // operand of the fusion instruction C_fusion, so then we fuse B:
      //
      //   A = ...
      //   B = op(A)
      //   C_fusion = { A' = ...
      //                B' = op(A)
      //                C' = op(A', B') }
      //
      // Now A is an operand of C_fusion again, so we then fuse A (again!):
      //
      //   A = ...
      //   B = op(A)
      //   C_fusion = { A' = ...
      //                A" = ..
      //                B' = op(A")
      //                C' = op(A', B') }
      //
      // We prevent this duplication by considering the operands in the reverse
      // order they appear in the instruction post order. In the example, this
      // ensures that B will be considered before A.
      //
      // We store the original indices of the operands to pass to ShouldFuse.
      std::vector<int64> sorted_operand_numbers(instruction->operands().size());
      std::iota(std::begin(sorted_operand_numbers),
                std::end(sorted_operand_numbers), 0);
      std::sort(
          sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
          [&](int64 i, int64 j) {
            // Instructions with higher indices in the post order come
            // first.
            return (
                FindOrDie(post_order_index, instruction->mutable_operand(i)) >
                FindOrDie(post_order_index, instruction->mutable_operand(j)));
          });

      for (int64 i : sorted_operand_numbers) {
        HloInstruction* operand = instruction->mutable_operand(i);

        if (!operand->IsFusable()) {
          continue;
        }
        if (!ShouldFuse(instruction, i)) {
          continue;
        }
        if (do_not_fuse.count(operand) > 0) {
          continue;
        }
        HloInstruction* fusion_instruction = Fuse(operand, instruction);

        // Fusing an instruction into a fusion instruction can change the
        // operand set of the fusion instruction. For simplicity just push the
        // instruction to the top of the post_order and reconsider it for
        // further fusion in the next iteration of the outer loop.
        post_order.push_back(fusion_instruction);
        InsertOrDie(&post_order_index, fusion_instruction,
                    post_order.size() - 1);
        changed = true;

        if (operand->user_count() == 0) {
          // Operand is now dead. Remove from post order by setting it's
          // location to nullptr.
          post_order[FindOrDie(post_order_index, operand)] = nullptr;
          post_order_index.erase(operand);

          // Remove from computation.
          TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
        }
        // Only one fusion per iteration: the instruction (or its fusion node)
        // was re-pushed above and will be reconsidered by the outer loop.
        break;
      }
    }
  }

  VLOG(2) << "After instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  return changed;
}
    398 
    399 HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
    400                                         HloInstruction* consumer) {
    401   HloInstruction* fusion_instruction;
    402 
    403   VLOG(2) << "Fusing " << producer->ToString() << " into "
    404           << consumer->ToString();
    405   auto kind = ChooseKind(producer, consumer);
    406   if (consumer->opcode() == HloOpcode::kFusion) {
    407     fusion_instruction = consumer;
    408     if (kind != fusion_instruction->fusion_kind()) {
    409       fusion_instruction->set_fusion_kind(kind);
    410     }
    411   } else {
    412     fusion_instruction = computation_->AddInstruction(
    413         HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
    414     TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
    415   }
    416 
    417   fusion_instruction->FuseInstruction(producer);
    418   return fusion_instruction;
    419 }
    420 
    421 bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
    422                                    int64 operand_index) {
    423   HloInstruction* producer = consumer->mutable_operand(operand_index);
    424   // Cost condition: don't duplicate expensive instructions.
    425   if (FusionWouldDuplicate(*producer, *consumer) &&
    426       (is_expensive_(*producer) || !may_duplicate_)) {
    427     return false;
    428   }
    429 
    430   if (consumer->opcode() == HloOpcode::kFusion &&
    431       consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
    432       consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
    433       consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
    434     return false;
    435   }
    436 
    437   if (producer->CouldBeBitcast() &&
    438       // We can't fuse parameters anyhow, so we leave the user unfused to become
    439       // a bitcast. If the operand is not a parameter, we would break a
    440       // potential fusion to make it a bitcast, which is not so clear a win.
    441       producer->operand(0)->opcode() == HloOpcode::kParameter) {
    442     return false;
    443   }
    444 
    445   return true;
    446 }
    447 
// Chooses the fusion kind for a producer/consumer pair. This base
// implementation always picks loop fusion; both parameters are unused here.
HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return HloInstruction::FusionKind::kLoop;
}
    452 
    453 }  // namespace xla
    454