/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace xla {

using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;

// Finds and returns the non-constant operand in instr.
//
// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
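//
// For example (an illustrative sketch; %x is a made-up operand name), for
//
//   add = add(constant(1), %x)
//
// NonConstantOperand(add) returns %x.  If add had two distinct non-constant
// operands (or none at all), one of the CHECKs below would fire.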
static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
  const HloInstruction* result = nullptr;
  for (const HloInstruction* operand : instr->operands()) {
    if (!operand->IsConstant()) {
      if (result != nullptr) {
        CHECK_EQ(result, operand);
      }
      result = operand;
    }
  }
  CHECK_NE(result, nullptr);
  return result;
}

// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
static bool IsOrContainsSendOrRecv(const HloInstruction* instr);

// Determines whether the given computation contains a send or recv node.
static bool ContainsSendOrRecv(const HloComputation* comp) {
  for (const auto* instr : comp->instructions()) {
    if (IsOrContainsSendOrRecv(instr)) {
      return true;
    }
  }
  return false;
}

static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kSend ||
      instr->opcode() == HloOpcode::kSendDone ||
      instr->opcode() == HloOpcode::kRecv ||
      instr->opcode() == HloOpcode::kRecvDone) {
    return true;
  }
  for (const auto& subcomp : instr->called_computations()) {
    if (ContainsSendOrRecv(subcomp)) {
      return true;
    }
  }
  return false;
}

// If all of instr's operands are either constants or have the form
//   get-tuple-element(gte_operand, N)
// for the same value N, returns N.  Otherwise, returns nullopt.
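//
// For example (an illustrative sketch; the instruction names are made up),
// given
//
//   gte.5 = get-tuple-element(param0), index=5
//   add   = add(gte.5, constant(1))
//
// GetGTEOperandIndex(add, param0) returns 5.  It returns nullopt if instr
// also used a gte of param0 at a different index, or any non-constant operand
// that is not a get-tuple-element of gte_operand.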
static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
                                          const HloInstruction* gte_operand) {
  VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
          << gte_operand->ToString() << ")";
  optional<int64> tuple_idx;
  for (const HloInstruction* operand : instr->operands()) {
    if (operand->IsConstant()) {
      continue;
    }
    if (operand->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(2) << "instr uses something other than gte(gte_operand): "
              << operand->ToString();
      return nullopt;
    }
    if (operand->operand(0) != gte_operand) {
      VLOG(2) << "instr has gte whose operand is not gte_operand: "
              << operand->ToString();
      return nullopt;
    }
    if (tuple_idx && tuple_idx != operand->tuple_index()) {
      VLOG(2) << "instr has operands with conflicting gte indices, "
              << *tuple_idx << " vs " << operand->tuple_index();
      return nullopt;
    }

    tuple_idx = operand->tuple_index();
  }
  return tuple_idx;
}

// Tries to get the tuple index of the induction variable of a while loop.
//
// Checks that the loop condition and the loop body's root both plumb the
// induction variable through the same tuple index, and that they both apply
// exactly one op to the induction variable, either before deciding whether to
// do another loop iteration (in the loop condition's case) or before packing
// the induction variable into the result tuple (in the loop body's case).
//
// Specifically, checks that the loop condition has structure
//
//   root = op(constants, get-tuple-elem(param0, N), constants)
//
// and the loop body has the structure
//
//   inc = op(constants, get-tuple-elem(param0, N), constants)
//   root = tuple(..., inc, ...)  // inc is N'th operand of tuple().
//
// If so, returns N.  Otherwise, returns nullopt.
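//
// As a concrete (made-up) illustration: for a loop carrying the tuple
// (s32[] i, f32[100] data), with
//
//   condition root:  less-than(get-tuple-elem(param0, 0), constant(10))
//   body:            inc  = add(get-tuple-elem(param0, 0), constant(1))
//                    root = tuple(inc, ...)
//
// the induction variable is plumbed through tuple index 0, so this returns 0.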
static optional<int64> GetLoopInductionVarTupleIdx(
    const HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
  VLOG(2) << "Finding induction variable for loop "
          << while_op->ToShortString();

  // The while_cond computation should have the form
  //
  //   while_cond_root =
  //       op(constants, get-tuple-elem(while_cond_param, N), constants).
  //
  // If it does, set indvar_tuple_idx to N.
  auto* while_cond = while_op->while_condition();
  auto* while_cond_root = while_cond->root_instruction();
  auto* while_cond_param = while_cond->parameter_instruction(0);
  optional<int64> indvar_tuple_idx =
      GetGTEOperandIndex(while_cond_root, while_cond_param);
  if (!indvar_tuple_idx) {
    VLOG(2) << "Induction variable not found in loop condition: "
            << while_cond->root_instruction()->ToString();
    return nullopt;
  }

  // The while_body computation should have the form
  //
  //   while_body_inc =
  //       op(constants, get-tuple-elem(while_body_param, N), constants)
  //   while_body_root = tuple(..., while_body_inc, ...)
  //
  // where while_body_inc is operand N of while_body_root.
  auto* while_body = while_op->while_body();
  auto* while_body_root = while_body->root_instruction();
  if (while_body_root->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While body's root is not a tuple instruction: "
            << while_body_root->ToString();
    return nullopt;
  }

  auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
  auto* while_body_param = while_body->parameter_instruction(0);
  optional<int64> while_body_indvar_tuple_idx =
      GetGTEOperandIndex(while_body_inc, while_body_param);
  if (!while_body_indvar_tuple_idx) {
    VLOG(2)
        << "Induction variable not found in while body increment instruction: "
        << while_body_inc->ToString();
    return nullopt;
  }
  if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
    VLOG(2) << "Tuple index of induction variable does not match between loop "
               "condition ("
            << *indvar_tuple_idx << ") and while body ("
            << *while_body_indvar_tuple_idx << ")";
    return nullopt;
  }

  // Finally, check that the while loop's initial value is a tuple with enough
  // elements.
  auto* while_init = while_op->operand(0);
  if (while_init->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
    return nullopt;
  }

  VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
  return indvar_tuple_idx;
}

// Tries to determine the number of times the given loop executes.  Currently
// simply returns 0, 1, or "can't tell" (nullopt).
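//
// For example (illustrative values): with induction variable init 0 and
// condition i < 0, the condition is false up front, so the trip count is 0.
// With init 0, condition i < 1, and update i = i + 1, the condition is true
// once and false after the first update, so the trip count is 1.  A loop
// whose condition is still true after one update is reported as nullopt.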
static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
  VLOG(2) << "Getting trip count for loop " << while_op->ToString();

  // The loop's induction variable is found at
  //
  //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
  //
  // where comp is while_op->while_body() or while_op->while_condition().
  optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
  if (!indvar_tuple_idx) {
    return nullopt;
  }

  VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx
          << " in input tuple.";

  // Now that we know the index of the induction variable, we can try to
  // compute how many times the loop executes.  Start by computing the
  // induction variable's initial value.
  HloEvaluator evaluator;
  auto* while_init = while_op->mutable_operand(0);
  auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
  StatusOr<std::unique_ptr<Literal>> indvar_init_result =
      evaluator.Evaluate(indvar_init);
  if (!indvar_init_result.ok()) {
    VLOG(2) << "Couldn't evaluate induction variable init: "
            << indvar_init_result.status();
    return nullopt;
  }

  // Evaluates the while loop's condition, returning either "true" (continue
  // looping), "false" (stop looping), or nullopt (can't evaluate).
  auto evaluate_while_cond = [&](const Literal& indvar) -> optional<bool> {
    auto* while_cond = while_op->while_condition();
    auto* while_cond_root = while_cond->root_instruction();
    auto* while_cond_indvar = NonConstantOperand(while_cond_root);
    StatusOr<std::unique_ptr<Literal>> result =
        evaluator.EvaluateWithSubstitutions(while_cond_root,
                                            {{while_cond_indvar, &indvar}});
    if (!result.ok()) {
      VLOG(2) << "Couldn't evaluate while cond: " << result.status();
      return nullopt;
    }
    return result.ValueOrDie()->data<bool>() ==
           tensorflow::gtl::ArraySlice<bool>{true};
  };

  // The initial value of the induction variable.
  const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie();

  // Evaluate whether the while condition is true when seeded with
  // indvar_iter0_val.
  optional<bool> while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val);
  if (while_cond_iter0_val == false) {
    VLOG(2) << "Loop has static trip count of 0.";
    return 0;
  }

  // Calculate the value of the induction variable after one iteration of the
  // loop, and check whether the while condition is true with this new value.
  auto* while_body = while_op->while_body();
  auto* while_body_indvar_update =
      while_body->root_instruction()->operand(*indvar_tuple_idx);
  auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
  StatusOr<std::unique_ptr<Literal>> indvar_iter1_result =
      evaluator.EvaluateWithSubstitutions(
          while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}});
  if (!indvar_iter1_result.ok()) {
    VLOG(2) << "Couldn't evaluate induction variable update: "
            << indvar_iter1_result.status();
    return nullopt;
  }
  const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie();
  optional<bool> while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val);
  if (while_cond_iter1_val == false) {
    VLOG(2) << "Determined that loop has static trip count of 1.";
    return 1;
  }

  VLOG(2) << "Loop has unknown trip count >= 1.";
  return nullopt;
}

// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
//
// Specifically, if a loop is tuple-shaped, and there exists some element of
// that tuple that is not used by the loop condition and is not used by the loop
// body except to pass it to the next iteration of the loop, then we can remove
// that element from the loop's tuples.
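//
// For instance (a hypothetical loop): if the carried tuple is
// (s32[] i, f32[] unused, f32[] acc), the condition only reads i, and the
// body updates i and acc while forwarding unused untouched to the next
// iteration, then the unused element can be dropped and the loop rebuilt to
// carry only (s32[] i, f32[] acc).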
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);

  // Don't try this transformation if the while loop isn't removable, since if
  // it succeeds ultimately we're going to have to replace the old while loop
  // with a new one.
  if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
    return false;
  }

  HloModule* module = while_op->GetModule();
  HloComputation* computation = while_op->parent();
  HloInstruction* while_init = while_op->mutable_operand(0);
  HloComputation* while_cond = while_op->while_condition();
  HloComputation* while_body = while_op->while_body();
  HloInstruction* while_body_root = while_body->root_instruction();

  if (!ShapeUtil::IsTuple(while_init->shape())) {
    VLOG(2) << "While op's carried value isn't tuple shaped.";
    return false;
  }

  if (while_body_root->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While body's root is not a tuple(...) instruction.";
    return false;
  }

  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);

  // Bail if param0 of while_cond or while_body has users which aren't of type
  // get-tuple-element.
  for (const HloInstruction* instr : {while_body->parameter_instruction(0),
                                      while_cond->parameter_instruction(0)}) {
    for (const HloInstruction* user : instr->users()) {
      if (user->opcode() != HloOpcode::kGetTupleElement) {
        VLOG(2) << "Cowardly refusing to analyze while loop with "
                << instr->ToString(print_no_metadata)
                << " used by non-GTE instruction "
                << user->ToString(print_no_metadata) << " in computation "
                << instr->parent()->name();
        return false;
      }
    }
  }

  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
  if (tuple_size == 0) {
    VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
               "empty.";
    return false;
  }

  tensorflow::gtl::FlatSet<int64> used_tuple_indices;
  for (HloComputation* comp : {while_body, while_cond}) {
    // The HLO verifier ensures that while_input's shape matches while_init's
    // shape, which we verified above is a tuple.
    HloInstruction* while_input = comp->parameter_instruction(0);

    for (const HloInstruction* user : while_input->users()) {
      // This user doesn't count if it's only used by the while body's root, and
      // the root places the tuple element into the same index of the tuple as
      // it came from.  That just amounts to us carrying the variable through
      // the loop.
      //
      // Careful: HloInstruction::operand_index returns the first index the
      // operand appears in, but it may appear more than once!
      if (user->user_count() == 1 && user->users().front() == while_body_root &&
          while_body_root->operand_index(user) == user->tuple_index() &&
          std::count(while_body_root->operands().begin(),
                     while_body_root->operands().end(), user) == 1) {
        continue;
      }

      used_tuple_indices.insert(user->tuple_index());
      if (used_tuple_indices.size() == tuple_size) {
        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
                << " uses all of its inputs; no simplification possible.";
        return false;
      }
    }
  }

  // If a tuple element is not passed unmodified from the while body's param0
  // through to the while body's root, count that element as "used", since
  // removing that element would be observable.
  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
    if (used_tuple_indices.count(i)) {
      continue;
    }

    auto* operand = while_body_root->operand(i);
    if (operand->opcode() != HloOpcode::kGetTupleElement ||
        operand->operand(0) != while_body->parameter_instruction(0) ||
        operand->tuple_index() != i) {
      VLOG(2) << "Tuple index " << i
              << " is not passed through loop body unmodified.";
      used_tuple_indices.insert(i);

      if (used_tuple_indices.size() == tuple_size) {
        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
                << " uses all of its inputs; no simplification possible.";
        return false;
      }
    }
  }

  // If we got here, used_tuple_indices.size() < tuple_size, meaning some
  // elements of the loop's tuple aren't used by while_body or while_cond.
  CHECK_LT(used_tuple_indices.size(), tuple_size);

  VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
          << " elements from tuple of "
          << while_op->ToString(print_no_metadata);

  // Build up maps from the old/new to the new/old tuple indices.
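  //
  // For example (hypothetical indices): if only indices {0, 2} of a 4-element
  // tuple are used, then new_to_old_tuple_idx = [0, 2] and
  // old_to_new_tuple_idx = {0 -> 0, 2 -> 1}; indices 1 and 3 simply disappear
  // from the new while loop's tuple.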
  std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
                                          used_tuple_indices.end());
  std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());

  tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
  for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
    int64 old_idx = new_to_old_tuple_idx[new_idx];
    old_to_new_tuple_idx[old_idx] = new_idx;
    VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
  }

  // Compute the shape of the while op after we remove the dead indices.
  std::vector<Shape> new_while_tuple_elem_shapes;
  new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_tuple_elem_shapes.push_back(
        while_init->shape().tuple_shapes(old_idx));
  }
  Shape new_while_shape =
      ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);

  // Returns a map from elements in the computation to new instructions which
  // replace the old instructions after we remove unused elements from the while
  // tuple.
  auto make_while_computation_replacements = [&](const HloComputation* comp) {
    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;

    auto* param = comp->parameter_instruction(0);
    replacements.emplace(param, HloInstruction::CreateParameter(
                                    0, new_while_shape, param->name()));

    // Materialize param's users, since we're about to add new ones below.
    std::vector<HloInstruction*> materialized_users(param->users().begin(),
                                                    param->users().end());
    for (const auto* user : materialized_users) {
      // The while body root is handled separately.
      if (user == while_body_root) {
        continue;
      }
      CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
          << user->ToString(print_no_metadata);

      int64 old_idx = user->tuple_index();
      auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
      if (new_idx_iter != old_to_new_tuple_idx.end()) {
        // This is a GTE of an index that survives.  Replace it.
        replacements.emplace(
            user, HloInstruction::CreateGetTupleElement(user->shape(), param,
                                                        new_idx_iter->second));
      } else {
        // This is a GTE of an index that we've removed.  Remove it from the
        // cloned computation.
        CHECK(user->user_count() == 0 ||
              (user->user_count() == 1 &&
               user->users().front() == while_body_root))
            << "Instruction " << user->ToString(print_no_metadata)
            << " should be unused (except by root of while body), but has "
               "users: {"
            << tensorflow::str_util::Join(
                   user->users(), ", ",
                   [&](string* out, const HloInstruction* instr) {
                     tensorflow::strings::StrAppend(
                         out, instr->ToString(print_no_metadata));
                   })
            << "}";

        replacements.emplace(user, nullptr);
      }
    }
    return replacements;
  };

  // Create the new while condition, body, and init value.
  std::unique_ptr<HloComputation> new_while_cond =
      while_cond->CloneWithReplacements(
          make_while_computation_replacements(while_cond));

  std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      while_body_replacements = make_while_computation_replacements(while_body);
  std::vector<HloInstruction*> new_while_body_root_elems;
  new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_body_root_elems.push_back(
        while_body_root->mutable_operand(old_idx));
  }
  while_body_replacements.emplace(
      while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
  std::unique_ptr<HloComputation> new_while_body =
      while_body->CloneWithReplacements(std::move(while_body_replacements));

  // Add a new while_init instruction that repackages the old while_init
  // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
  // clean this up in the common case where while_init is a tuple op.  (It's
  // definitely tuple-shaped, but it's not necessarily a tuple op.)
  std::vector<HloInstruction*> new_while_init_elems;
  new_while_init_elems.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_init_elems.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
  }
  auto* new_while_init = computation->AddInstruction(
      HloInstruction::CreateTuple(new_while_init_elems));

  // Create the new while op.
  auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
      new_while_shape,
      module->AddEmbeddedComputation(std::move(new_while_cond)),
      module->AddEmbeddedComputation(std::move(new_while_body)),
      new_while_init));

  // Create a tuple op that recreates the output of the old while op.  That is,
  // we transform to
  //
  //  new_while_init   while_init
  //       |              |
  //       V              |
  //   new_while          |
  //       |              |
  //       -------|   |----
  //              V   V
  //            new_tuple
  //                |
  //                V
  //    (orig. users of while op)
  //
  // The tuple simplifier will then simplify this if possible, removing
  // new_tuple and while_init.
  std::vector<HloInstruction*> new_tuple_elems;
  for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
    auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
    if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
      int64 gte_idx = new_tuple_idx_it->second;
      new_tuple_elems.push_back(
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
              gte_idx)));
    } else {
      new_tuple_elems.push_back(
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
    }
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
  TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple));

  return true;
}

// Tries to remove a while loop from the graph.
//
//  - Loops with trip count of 0 can be replaced by the loop's "init" value.
//  - Loops with trip count of 1 can be replaced by the loop's body, with the
//    loop itself removed.
//
// Returns true if it made a change to the graph.
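//
// For example (a hypothetical loop): a while op whose induction variable
// starts at 0 with condition i < 0 never runs, so it is replaced by its init
// operand.  A loop with init 0, condition i < 1, and update i = i + 1 runs
// exactly once, so it is rewritten as a call to the loop body and the call is
// then inlined.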
static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
  // Cowardly refuse to remove loops that are not removable.  In practice,
  // this means that we can't remove loops that contain side-effecting
  // instructions or have control predecessors/successors.
  //
  // This is not a fundamental limitation.  The control operands can be moved
  // onto the new HLOs after simplification, and any side-effecting ops inside
  // the loop aren't removed, just cloned and added back to the loop.  But
  // moving an op out of the loop also removes implicit control dependencies
  // between the op and the ops outside the loop, so we'd have to add those back
  // for things like infeed/outfeed.  It gets complicated.  So for now we just
  // avoid it.
  if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
    VLOG(2) << "Not attempting to remove while loop; it is not removable: "
            << while_op->ToShortString();
    return false;
  }

  // Remove while loops with static trip count of 0.
  optional<int64> trip_count = GetLoopTripCount(while_op);
  if (trip_count && *trip_count == 0) {
    // The loop never executes, so the value of the loop is the value of its
    // "init" operand.
    auto computation = while_op->parent();

    // Remove while_op (i.e., call ReplaceInstruction rather than
    // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
    // a loop without an intervening DCE, we don't try to re-remove the loop.
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
        while_op, while_op->mutable_operand(0)));
    return true;
  }

  // Transform while loops with static trip count of 1 into a call op, then
  // inline the call.
  if (trip_count && *trip_count == 1) {
    auto computation = while_op->parent();
    auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
        while_op->shape(), while_op->operands(), while_op->while_body()));
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
    TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
                        CallInliner::Inline(call_op));
    (void)inlined_instructions_map;
    return true;
  }
  return false;
}

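// A minimal usage sketch (illustrative only; it just drives the pass defined
// below and assumes a live HloModule* named `module`):
//
//   WhileLoopSimplifier simplifier;
//   TF_ASSIGN_OR_RETURN(bool changed, simplifier.Run(module));
//
// `changed` is true if any while loop was removed outright or had dead tuple
// elements pruned.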
StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(3,
                 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  // Gather all the while ops in our module.  We do this ahead of time so we
  // don't have to worry about mutating the lists of computations or
  // instructions while we iterate.
  std::vector<HloInstruction*> while_ops;
  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      if (instr->opcode() == HloOpcode::kWhile) {
        while_ops.push_back(instr);
      }
    }
  }

  for (HloInstruction* while_op : while_ops) {
    // We can't remove while loops that contain send/recv nodes, because we rely
    // on the particular loop structure around the node matching on the send and
    // recv sides.  Removing dead while params requires us to remove the loop
    // and replace it with a new one, so we can't do that either.
    if (ContainsSendOrRecv(while_op->while_body()) ||
        ContainsSendOrRecv(while_op->while_condition())) {
      VLOG(2) << "Not attempting to simplify while loop because it contains a "
                 "send/recv node: "
              << while_op->ToShortString();
      continue;
    }

    StatusOr<bool> result = TryRemoveWhileLoop(while_op);
    TF_RETURN_IF_ERROR(result.status());
    if (result.ValueOrDie()) {
      changed = true;
      // Don't try to remove dead while params after successfully removing the
      // while loop -- that would result in use-after-free nastiness.
      continue;
    }

    result = TryRemoveDeadWhileParams(while_op);
    TF_RETURN_IF_ERROR(result.status());
    changed |= result.ValueOrDie();
  }

  XLA_VLOG_LINES(3,
                 "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla