Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
     17 #include "tensorflow/compiler/xla/service/tuple_util.h"
     18 #include "tensorflow/compiler/xla/service/while_util.h"
     19 #include "tensorflow/compiler/xla/util.h"
     20 #include "tensorflow/core/lib/gtl/flatmap.h"
     21 #include "tensorflow/core/lib/gtl/flatset.h"
     22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     23 
     24 namespace xla {
     25 
     26 using tensorflow::gtl::FlatMap;
     27 using tensorflow::gtl::FlatSet;
     28 using tensorflow::gtl::InlinedVector;
     29 
     30 // Copies `to_hoist` to the computation containing `while_instr`, hoisting its
     31 // operands as needed.  All of its transitive operands are expected to be either
     32 // in `hoisted_instructions` or `unhoisted_invariant_instructions`.  This
     33 // function hoists the operands in `unhoisted_invariant_instructions` and moves
     34 // them into `hoisted_instructions`.
     35 static void CreateLoopInvariantCopy(
     36     FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
     37     FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
     38     HloInstruction* while_instr, HloInstruction* to_hoist) {
     39   HloComputation* parent_of_while = while_instr->parent();
     40   HloComputation* while_body = while_instr->while_body();
     41 
     42   struct DFSFrame {
     43     HloInstruction* instruction;
     44     int64 operand_index;
     45   };
     46 
     47   InlinedVector<DFSFrame, 8> dfs_stack;
     48   dfs_stack.push_back({to_hoist, 0});
     49 
     50   HloInstruction* while_body_param = while_body->parameter_instruction(0);
     51   HloInstruction* while_operand = while_instr->mutable_operand(0);
     52 
     53   do {
     54     DFSFrame* frame = &dfs_stack.back();
     55     if (frame->operand_index == frame->instruction->operand_count()) {
     56       HloInstruction* old_instruction = frame->instruction;
     57 
     58       // All of the operands for old_instruction have been cloned, so it is
     59       // time to clone old_instruction itself.
     60 
     61       auto get_new_operand = [&](HloInstruction* old_operand) {
     62         return old_operand == while_body_param
     63                    ? while_operand
     64                    : FindOrDie(*hoisted_instructions, old_operand);
     65       };
     66 
     67       InlinedVector<HloInstruction*, 4> new_operands;
     68       c_transform(old_instruction->operands(), std::back_inserter(new_operands),
     69                   get_new_operand);
     70 
     71       HloInstruction* new_instruction =
     72           parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
     73               old_instruction->shape(), new_operands));
     74 
     75       InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
     76 
     77       // Approximately half of the instructions that would normally be present
     78       // in unhoisted_invariant_instructions are constants.  We save a bit of
     79       // compile time by not putting these in the hashtable.
     80       CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction),
     81                to_hoist != old_instruction &&
     82                    old_instruction->opcode() != HloOpcode::kConstant);
     83       dfs_stack.pop_back();
     84       continue;
     85     }
     86 
     87     HloInstruction* next_operand =
     88         frame->instruction->mutable_operand(frame->operand_index++);
     89     if (hoisted_instructions->count(next_operand) ||
     90         next_operand == while_body_param) {
     91       continue;
     92     }
     93 
     94     dfs_stack.push_back({next_operand, 0});
     95   } while (!dfs_stack.empty());
     96 }
     97 
     98 // Returns true if `instruction` is worth hoisting only if it lets us hoist some
     99 // instruction using it.  The rationale is that hoisting these instructions will
    100 // prevent simplification and fusion in the while body.
    101 static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
    102   switch (instruction.opcode()) {
    103     default:
    104       return false;
    105 
    106     case HloOpcode::kBitcast:
    107     case HloOpcode::kBroadcast:
    108     case HloOpcode::kConstant:
    109     case HloOpcode::kReverse:
    110     case HloOpcode::kSlice:
    111     case HloOpcode::kTuple:
    112       return true;
    113 
    114     case HloOpcode::kTranspose:
    115       return ShapeUtil::TransposeIsBitcast(
    116           /*input_shape=*/instruction.operand(0)->shape(),
    117           /*output_shape=*/instruction.shape(), instruction.dimensions());
    118 
    119     case HloOpcode::kReshape:
    120       return ShapeUtil::ReshapeIsBitcast(
    121           /*input_shape=*/instruction.operand(0)->shape(),
    122           /*output_shape=*/instruction.shape());
    123   }
    124 }
    125 
    126 // Populates `gte_set` with the GetTupleElement instructions in `while_body`
    127 // that access elements in the parameter tuple that don't change across
    128 // iterations.  Assumes `while_body` is the body computation of the while loop
    129 // in question.
    130 static void GatherInvariantGTEs(HloComputation* while_body,
    131                                 FlatSet<HloInstruction*>* gte_set) {
    132   const HloInstruction::InstructionVector root_operands =
    133       while_body->root_instruction()->operands();
    134   for (int i = 0; i < root_operands.size(); i++) {
    135     HloInstruction* instr = root_operands[i];
    136     if (instr->opcode() == HloOpcode::kGetTupleElement &&
    137         instr->tuple_index() == i &&
    138         instr->operand(0) == while_body->parameter_instruction(0) &&
    139         ShapeUtil::IsArray(instr->shape())) {
    140       InsertOrDie(gte_set, instr);
    141     }
    142   }
    143 }
    144 
    145 static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
    146     HloInstruction* while_instr) {
    147   auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
    148 
    149   if (!ShapeUtil::IsTuple(while_instr->shape())) {
    150     // This restriction leaves one interesting pattern on the table:
    151     //
    152     //  while_body(f32[1024, 1024] %param) {
    153     //    %value = expensive_op(%param)
    154     //    outfeed(%value)
    155     //    ROOT = %param
    156     //  }
    157     //
    158     // If we see that pattern in the while, instead of generalizing this
    159     // algorithm to work with non-tuples, we should instead add a pass that
    160     // canonicalizes while loops like the above to use a tuple state.
    161     return false;
    162   }
    163 
    164   string while_instr_name = while_instr->ToString(print_no_metadata);
    165   VLOG(2) << "Trying to hoist from " << while_instr_name;
    166 
    167   HloComputation* while_body = while_instr->while_body();
    168 
    169   // Maps instructions in the while body to instructions hoisted outside the
    170   // while that compute the same value.
    171   FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
    172 
    173   // Contains instructions that can be legally hoisted, but were deemed to be
    174   // unprofitable to be hoisted alone by NotWorthHoistingIndividually.  When we
    175   // hoist an instruction in this set, we move it from
    176   // unhoisted_invariant_instructions to hoisted_instructions.
    177   FlatSet<HloInstruction*> unhoisted_invariant_instructions;
    178 
    179   // Invariant GTE's axiomatically satisfy the constraints for
    180   // unhoisted_invariant_instructions -- they can be legally hoisted, but there
    181   // is no benefit to hoisting them unless something that uses it is also
    182   // hoisted.
    183   GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions);
    184 
    185   if (unhoisted_invariant_instructions.empty()) {
    186     // There are no obviously loop invariant elements in the state being
    187     // threaded through the while loop so give up.  In theory this precondition
    188     // is too strong -- we could have code that e.g. permutes the elements in
    189     // the while state but uses a select to pick the same value on every
    190     // iteration.
    191     return false;
    192   }
    193 
    194   // instructions_to_replace[i] is hoisted into a loop invariant instruction
    195   // replacement_instructions[i].
    196   std::vector<HloInstruction*> instructions_to_replace;
    197   std::vector<HloInstruction*> replacement_instructions;
    198 
    199   for (auto* instruction : while_body->MakeInstructionPostOrder()) {
    200     if (instruction->HasSideEffect() ||
    201         instruction->opcode() == HloOpcode::kParameter ||
    202         !instruction->control_predecessors().empty() ||
    203         !instruction->control_successors().empty()) {
    204       continue;
    205     }
    206 
    207     auto is_invariant = [&](HloInstruction* op) {
    208       return hoisted_instructions.find(op) != hoisted_instructions.end() ||
    209              unhoisted_invariant_instructions.count(op) ||
    210              op->opcode() == HloOpcode::kConstant;
    211     };
    212 
    213     if (!c_all_of(instruction->operands(), is_invariant)) {
    214       continue;
    215     }
    216 
    217     if (NotWorthHoistingIndividually(*instruction)) {
    218       VLOG(2) << "Adding " << instruction->ToString(print_no_metadata)
    219               << " to unhoisted invariant set.";
    220       // Approximately half of the instructions that reach this point are
    221       // constants.  We save a bit of compile time by not putting these in the
    222       // hashtable.
    223       if (instruction->opcode() != HloOpcode::kConstant) {
    224         InsertOrDie(&unhoisted_invariant_instructions, instruction);
    225       }
    226       continue;
    227     }
    228 
    229     VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);
    230 
    231     CreateLoopInvariantCopy(&hoisted_instructions,
    232                             &unhoisted_invariant_instructions, while_instr,
    233                             instruction);
    234 
    235     instructions_to_replace.push_back(instruction);
    236     replacement_instructions.push_back(
    237         FindOrDie(hoisted_instructions, instruction));
    238   }
    239 
    240   if (instructions_to_replace.empty()) {
    241     return false;
    242   }
    243 
    244   TF_ASSIGN_OR_RETURN(
    245       WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
    246       WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));
    247 
    248   HloComputation* new_while_body =
    249       live_in_instructions_result.new_while_instr->while_body();
    250 
    251   for (int i = 0; i < instructions_to_replace.size(); i++) {
    252     HloInstruction* instruction_to_replace_in_new_while =
    253         FindOrDie(live_in_instructions_result.while_body_instruction_map,
    254                   instructions_to_replace[i]);
    255     TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
    256         instruction_to_replace_in_new_while,
    257         live_in_instructions_result.while_body_live_in_values[i]));
    258   }
    259 
    260   VLOG(1) << "Hoisted " << instructions_to_replace.size()
    261           << " instructions from " << while_instr_name;
    262 
    263   return true;
    264 }
    265 
    266 StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
    267   bool changed = false;
    268   std::vector<HloInstruction*> while_instrs;
    269   for (auto* comp : module->computations()) {
    270     c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
    271               [](const HloInstruction* instr) {
    272                 return instr->opcode() == HloOpcode::kWhile;
    273               });
    274   }
    275 
    276   for (HloInstruction* while_instr : while_instrs) {
    277     // Right now we only hoist computations from the while body, but
    278     // TryHoistingInvariantInstructionsFromWhileBody can be generalized to
    279     // optimize the condition computation too, if needed.
    280     //
    281     // The transform we do here is a pessmization for while loops that execute
    282     // zero times*, but at this time we expect those to be rare.  If this
    283     // becomes a problem we can consider using the conditional HLO to avoid
    284     // doing extra work for while loops with zero trip count.
    285     //
    286     // * We delete while loops that have a zero trip count, so this would have
    287     //   to be a while loop with a somewhat opaque condition expression.
    288 
    289     TF_ASSIGN_OR_RETURN(
    290         bool result,
    291         TryHoistingInvariantInstructionsFromWhileBody(while_instr));
    292     changed |= result;
    293   }
    294   return changed;
    295 }
    296 }  // namespace xla
    297