/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"

#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace xla {

using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;

// Finds and returns the non-constant operand in instr.
//
// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
  const HloInstruction* result = nullptr;
  for (const HloInstruction* operand : instr->operands()) {
    if (!operand->IsConstant()) {
      if (result != nullptr) {
        CHECK_EQ(result, operand);
      }
      result = operand;
    }
  }
  CHECK_NE(result, nullptr);
  return result;
}

// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
static bool IsOrContainsSendOrRecv(const HloInstruction* instr);

// Determines whether the given computation contains a send or recv node.
static bool ContainsSendOrRecv(const HloComputation* comp) {
  for (const auto* instr : comp->instructions()) {
    if (IsOrContainsSendOrRecv(instr)) {
      return true;
    }
  }
  return false;
}

static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kSend ||
      instr->opcode() == HloOpcode::kSendDone ||
      instr->opcode() == HloOpcode::kRecv ||
      instr->opcode() == HloOpcode::kRecvDone) {
    return true;
  }
  for (const auto& subcomp : instr->called_computations()) {
    if (ContainsSendOrRecv(subcomp)) {
      return true;
    }
  }
  return false;
}

// If all of instr's operands are either constants or have the form
//   get-tuple-element(gte_operand, N)
// for the same value N, returns N.  Otherwise, returns nullopt.
static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
                                          const HloInstruction* gte_operand) {
  VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
          << gte_operand->ToString() << ")";
  optional<int64> tuple_idx;
  for (const HloInstruction* operand : instr->operands()) {
    if (operand->IsConstant()) {
      continue;
    }
    if (operand->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(2) << "instr uses something other than gte(gte_operand): "
              << operand->ToString();
      return nullopt;
    }
    if (operand->operand(0) != gte_operand) {
      VLOG(2) << "instr has gte whose operand is not gte_operand: "
              << operand->ToString();
      return nullopt;
    }
    if (tuple_idx && tuple_idx != operand->tuple_index()) {
      VLOG(2) << "instr has operands with conflicting gte indices, "
              << *tuple_idx << " vs " << operand->tuple_index();
      return nullopt;
    }

    tuple_idx = operand->tuple_index();
  }
  return tuple_idx;
}

// Tries to get the tuple index of the induction variable of a while loop.
//
// Checks that the loop condition and root both plumb the induction variable
// through the same tuple index, and that they both apply exactly one op to the
// induction variable before deciding whether to do another loop iteration (in
// the loop condition's case) or packing the induction variable into the result
// tuple (in the loop body's case).
//
// Specifically, checks that the loop condition has the structure
//
//   root = op(constants, get-tuple-elem(param0, N), constants)
//
// and the loop body has the structure
//
//   inc = op(constants, get-tuple-elem(param0, N), constants)
//   root = tuple(..., inc, ...)  // inc is the N'th operand of tuple().
//
// If so, returns N.  Otherwise, returns nullopt.
static optional<int64> GetLoopInductionVarTupleIdx(
    const HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
  VLOG(2) << "Finding induction variable for loop "
          << while_op->ToShortString();

  // The while_cond computation should have the form
  //
  //   while_cond_root =
  //       op(constants, get-tuple-elem(while_cond_param, N), constants).
  //
  // If it does, set indvar_tuple_idx to N.
  auto* while_cond = while_op->while_condition();
  auto* while_cond_root = while_cond->root_instruction();
  auto* while_cond_param = while_cond->parameter_instruction(0);
  optional<int64> indvar_tuple_idx =
      GetGTEOperandIndex(while_cond_root, while_cond_param);
  if (!indvar_tuple_idx) {
    VLOG(2) << "Induction variable not found in loop condition: "
            << while_cond->root_instruction()->ToString();
    return nullopt;
  }

  // The while_body computation should have the form
  //
  //   while_body_inc =
  //       op(constants, get-tuple-elem(while_body_param, N), constants)
  //   while_body_root = tuple(..., while_body_inc, ...)
  //
  // where while_body_inc is operand N of while_body_root.
  auto* while_body = while_op->while_body();
  auto* while_body_root = while_body->root_instruction();
  if (while_body_root->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While body's root is not a tuple instruction: "
            << while_body_root->ToString();
    return nullopt;
  }

  auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
  auto* while_body_param = while_body->parameter_instruction(0);
  optional<int64> while_body_indvar_tuple_idx =
      GetGTEOperandIndex(while_body_inc, while_body_param);
  if (!while_body_indvar_tuple_idx) {
    VLOG(2)
        << "Induction variable not found in while body increment instruction: "
        << while_body_inc->ToString();
    return nullopt;
  }
  if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
    VLOG(2) << "Tuple index of induction variable does not match between loop "
               "condition ("
            << *indvar_tuple_idx << ") and while body ("
            << *while_body_indvar_tuple_idx << ")";
    return nullopt;
  }

  // Finally, check that the while loop's initial value is a tuple with enough
  // elements.
  auto* while_init = while_op->operand(0);
  if (while_init->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
    return nullopt;
  }

  VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
  return indvar_tuple_idx;
}

// Tries to determine the number of times the given loop executes.  Currently
// simply returns 0, 1, or "can't tell" (nullopt).
static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
  VLOG(2) << "Getting trip count for loop " << while_op->ToString();

  // The loop's induction variable is found at
  //
  //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
  //
  // where comp is while_op->while_body() or while_op->while_condition().
  optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
  if (!indvar_tuple_idx) {
    return nullopt;
  }

  VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx
          << " in input tuple.";

  // Now that we know the index of the induction variable, we can try to
  // compute how many times the loop executes.  Start by computing the
  // induction variable's initial value.
  HloEvaluator evaluator;
  auto* while_init = while_op->mutable_operand(0);
  auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
  StatusOr<std::unique_ptr<Literal>> indvar_init_result =
      evaluator.Evaluate(indvar_init);
  if (!indvar_init_result.ok()) {
    VLOG(2) << "Couldn't evaluate induction variable init: "
            << indvar_init_result.status();
    return nullopt;
  }

  // Evaluates the while loop's condition, returning either "true" (continue
  // looping), "false" (stop looping), or nullopt (can't evaluate).
  auto evaluate_while_cond = [&](const Literal& indvar) -> optional<bool> {
    auto* while_cond = while_op->while_condition();
    auto* while_cond_root = while_cond->root_instruction();
    auto* while_cond_indvar = NonConstantOperand(while_cond_root);
    StatusOr<std::unique_ptr<Literal>> result =
        evaluator.EvaluateWithSubstitutions(while_cond_root,
                                            {{while_cond_indvar, &indvar}});
    if (!result.ok()) {
      VLOG(2) << "Couldn't evaluate while cond: " << result.status();
      return nullopt;
    }
    return result.ValueOrDie()->data<bool>() ==
           tensorflow::gtl::ArraySlice<bool>{true};
  };

  // The initial value of the induction variable.
  const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie();

  // Evaluate whether the while condition is true when seeded with
  // indvar_iter0_val.
  optional<bool> while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val);
  if (while_cond_iter0_val == false) {
    VLOG(2) << "Loop has static trip count of 0.";
    return 0;
  }

  // Calculate the value of the induction variable after one iteration of the
  // loop, and check whether the while condition is true with this new value.
  auto* while_body = while_op->while_body();
  auto* while_body_indvar_update =
      while_body->root_instruction()->operand(*indvar_tuple_idx);
  auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
  StatusOr<std::unique_ptr<Literal>> indvar_iter1_result =
      evaluator.EvaluateWithSubstitutions(
          while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}});
  if (!indvar_iter1_result.ok()) {
    VLOG(2) << "Couldn't evaluate induction variable update: "
            << indvar_iter1_result.status();
    return nullopt;
  }
  const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie();
  optional<bool> while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val);
  if (while_cond_iter1_val == false) {
    VLOG(2) << "Determined that loop has static trip count of 1.";
    return 1;
  }

  VLOG(2) << "Loop has unknown trip count >= 1.";
  return nullopt;
}

// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
//
// Specifically, if a loop is tuple-shaped, and there exists some element of
// that tuple that is not used by the loop condition and is not used by the
// loop body except to pass it to the next iteration of the loop, then we can
// remove that element from the loop's tuples.
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);

  // Don't try this transformation if the while loop isn't removable, since if
  // it succeeds ultimately we're going to have to replace the old while loop
  // with a new one.
  if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
    return false;
  }

  HloModule* module = while_op->GetModule();
  HloComputation* computation = while_op->parent();
  HloInstruction* while_init = while_op->mutable_operand(0);
  HloComputation* while_cond = while_op->while_condition();
  HloComputation* while_body = while_op->while_body();
  HloInstruction* while_body_root = while_body->root_instruction();

  if (!ShapeUtil::IsTuple(while_init->shape())) {
    VLOG(2) << "While op's carried value isn't tuple shaped.";
    return false;
  }

  if (while_body_root->opcode() != HloOpcode::kTuple) {
    VLOG(2) << "While body's root is not a tuple(...) instruction.";
    return false;
  }

  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);

  // Bail if param0 of while_cond or while_body has users which aren't of type
  // get-tuple-element.
  for (const HloInstruction* instr : {while_body->parameter_instruction(0),
                                      while_cond->parameter_instruction(0)}) {
    for (const HloInstruction* user : instr->users()) {
      if (user->opcode() != HloOpcode::kGetTupleElement) {
        VLOG(2) << "Cowardly refusing to analyze while loop with "
                << instr->ToString(print_no_metadata)
                << " used by non-GTE instruction "
                << user->ToString(print_no_metadata) << " in computation "
                << instr->parent()->name();
        return false;
      }
    }
  }

  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
  if (tuple_size == 0) {
    VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
               "empty.";
    return false;
  }

  tensorflow::gtl::FlatSet<int64> used_tuple_indices;
  for (HloComputation* comp : {while_body, while_cond}) {
    // The HLO verifier ensures that while_input's shape matches while_init's
    // shape, which we verified above is a tuple.
    HloInstruction* while_input = comp->parameter_instruction(0);

    for (const HloInstruction* user : while_input->users()) {
      // This user doesn't count if it's only used by the while body's root,
      // and the root places the tuple element into the same index of the
      // tuple as it came from.  That just amounts to us carrying the variable
      // through the loop.
      //
      // Careful: HloInstruction::operand_index returns the first index the
      // operand appears in, but it may appear more than once!
      if (user->user_count() == 1 && user->users().front() == while_body_root &&
          while_body_root->operand_index(user) == user->tuple_index() &&
          std::count(while_body_root->operands().begin(),
                     while_body_root->operands().end(), user) == 1) {
        continue;
      }

      used_tuple_indices.insert(user->tuple_index());
      if (used_tuple_indices.size() == tuple_size) {
        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
                << " uses all of its inputs; no simplification possible.";
        return false;
      }
    }
  }

  // If a tuple element is not passed unmodified from the while body's param0
  // through to the while body's root, count that element as "used", since
  // removing that element would be observable.
  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
    if (used_tuple_indices.count(i)) {
      continue;
    }

    auto* operand = while_body_root->operand(i);
    if (operand->opcode() != HloOpcode::kGetTupleElement ||
        operand->operand(0) != while_body->parameter_instruction(0) ||
        operand->tuple_index() != i) {
      VLOG(2) << "Tuple index " << i
              << " is not passed through loop body unmodified.";
      used_tuple_indices.insert(i);

      if (used_tuple_indices.size() == tuple_size) {
        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
                << " uses all of its inputs; no simplification possible.";
        return false;
      }
    }
  }

  // If we got here, used_tuple_indices.size() < tuple_size, meaning some
  // elements of the loop's tuple aren't used by while_body or while_cond.
  CHECK_LT(used_tuple_indices.size(), tuple_size);

  VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
          << " elements from tuple of "
          << while_op->ToString(print_no_metadata);

  // Build up maps from the old/new to the new/old tuple indices.
  std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
                                          used_tuple_indices.end());
  std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());

  tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
  for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
    int64 old_idx = new_to_old_tuple_idx[new_idx];
    old_to_new_tuple_idx[old_idx] = new_idx;
    VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
  }

  // Compute the shape of the while op after we remove the dead indices.
  std::vector<Shape> new_while_tuple_elem_shapes;
  new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_tuple_elem_shapes.push_back(
        while_init->shape().tuple_shapes(old_idx));
  }
  Shape new_while_shape =
      ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);

  // Returns a map from elements in the computation to new instructions which
  // replace the old instructions after we remove unused elements from the
  // while tuple.
  auto make_while_computation_replacements = [&](const HloComputation* comp) {
    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;

    auto* param = comp->parameter_instruction(0);
    replacements.emplace(param, HloInstruction::CreateParameter(
                                    0, new_while_shape, param->name()));

    // Materialize param's users, since we're about to add new ones below.
    std::vector<HloInstruction*> materialized_users(param->users().begin(),
                                                    param->users().end());
    for (const auto* user : materialized_users) {
      // The while body root is handled separately.
      if (user == while_body_root) {
        continue;
      }
      CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
          << user->ToString(print_no_metadata);

      int64 old_idx = user->tuple_index();
      auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
      if (new_idx_iter != old_to_new_tuple_idx.end()) {
        // This is a GTE of an index that survives.  Replace it.
        replacements.emplace(
            user, HloInstruction::CreateGetTupleElement(user->shape(), param,
                                                        new_idx_iter->second));
      } else {
        // This is a GTE of an index that we've removed.  Remove it from the
        // cloned computation.
        CHECK(user->user_count() == 0 ||
              (user->user_count() == 1 &&
               user->users().front() == while_body_root))
            << "Instruction " << user->ToString(print_no_metadata)
            << " should be unused (except by root of while body), but has "
               "users: {"
            << tensorflow::str_util::Join(
                   user->users(), ", ",
                   [&](string* out, const HloInstruction* instr) {
                     tensorflow::strings::StrAppend(
                         out, instr->ToString(print_no_metadata));
                   })
            << "}";

        replacements.emplace(user, nullptr);
      }
    }
    return replacements;
  };

  // Create the new while condition, body, and init value.
  std::unique_ptr<HloComputation> new_while_cond =
      while_cond->CloneWithReplacements(
          make_while_computation_replacements(while_cond));

  std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      while_body_replacements = make_while_computation_replacements(while_body);
  std::vector<HloInstruction*> new_while_body_root_elems;
  new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_body_root_elems.push_back(
        while_body_root->mutable_operand(old_idx));
  }
  while_body_replacements.emplace(
      while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
  std::unique_ptr<HloComputation> new_while_body =
      while_body->CloneWithReplacements(std::move(while_body_replacements));

  // Add a new while_init instruction that repackages the old while_init
  // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
  // clean this up in the common case where while_init is a tuple op.  (It's
  // definitely tuple-shaped, but it's not necessarily a tuple op.)
  std::vector<HloInstruction*> new_while_init_elems;
  new_while_init_elems.reserve(new_to_old_tuple_idx.size());
  for (int64 old_idx : new_to_old_tuple_idx) {
    new_while_init_elems.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
  }
  auto* new_while_init = computation->AddInstruction(
      HloInstruction::CreateTuple(new_while_init_elems));

  // Create the new while op.
  auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
      new_while_shape,
      module->AddEmbeddedComputation(std::move(new_while_cond)),
      module->AddEmbeddedComputation(std::move(new_while_body)),
      new_while_init));

  // Create a tuple op that recreates the output of the old while op.  That is,
  // we transform to
  //
  //   new_while_init   while_init
  //        |               |
  //        V               |
  //    new_while           |
  //        |               |
  //        ------|   |------
  //              V   V
  //            new_tuple
  //                |
  //                V
  //    (orig. users of while op)
  //
  // The tuple simplifier will then simplify this if possible, removing
  // new_tuple and while_init.
  std::vector<HloInstruction*> new_tuple_elems;
  for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
    auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
    if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
      int64 gte_idx = new_tuple_idx_it->second;
      new_tuple_elems.push_back(
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
              gte_idx)));
    } else {
      new_tuple_elems.push_back(
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
    }
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
  TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple));

  return true;
}

// Tries to remove a while loop from the graph.
//
//  - Loops with trip count of 0 can be replaced by the loop's "init" value.
//  - Loops with trip count of 1 can be replaced by the loop's body, with the
//    loop itself removed.
//
// Returns true if it made a change to the graph.
static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
  // Cowardly refuse to remove loops that are not removable.  In practice,
  // this means that we can't remove loops that contain side-effecting
  // instructions or have control predecessors/successors.
  //
  // This is not a fundamental limitation.  The control operands can be moved
  // onto the new HLOs after simplification, and any side-effecting ops inside
  // the loop aren't removed, just cloned and added back to the loop.  But
  // moving an op out of the loop also removes implicit control dependencies
  // between the op and the ops outside the loop, so we'd have to add those
  // back for things like infeed/outfeed.  It gets complicated.  So for now we
  // just avoid it.
  if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
    VLOG(2) << "Not attempting to remove while loop because it is not "
               "removable: "
            << while_op->ToShortString();
    return false;
  }

  // Remove while loops with static trip count of 0.
  optional<int64> trip_count = GetLoopTripCount(while_op);
  if (trip_count && *trip_count == 0) {
    // The loop never executes, so the value of the loop is the value of its
    // "init" operand.
    auto computation = while_op->parent();

    // Remove while_op (i.e., call ReplaceInstruction rather than
    // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run
    // in a loop without an intervening DCE, we don't try to re-remove the
    // loop.
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
        while_op, while_op->mutable_operand(0)));
    return true;
  }

  // Transform while loops with static trip count of 1 into a call op, then
  // inline the call.
  if (trip_count && *trip_count == 1) {
    auto computation = while_op->parent();
    auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
        while_op->shape(), while_op->operands(), while_op->while_body()));
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
    TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
                        CallInliner::Inline(call_op));
    (void)inlined_instructions_map;
    return true;
  }
  return false;
}

StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(3,
                 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  // Gather all the while ops in our module.  We do this ahead of time so we
  // don't have to worry about mutating the lists of computations or
  // instructions while we iterate.
  std::vector<HloInstruction*> while_ops;
  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      if (instr->opcode() == HloOpcode::kWhile) {
        while_ops.push_back(instr);
      }
    }
  }

  for (HloInstruction* while_op : while_ops) {
    // We can't remove while loops that contain send/recv nodes, because we
    // rely on the particular loop structure around the node matching on the
    // send and recv sides.  Removing dead while params requires us to remove
    // the loop and replace it with a new one, so we can't do that either.
    if (ContainsSendOrRecv(while_op->while_body()) ||
        ContainsSendOrRecv(while_op->while_condition())) {
      VLOG(2) << "Not attempting to simplify while loop because it contains a "
                 "send/recv node: "
              << while_op->ToShortString();
      continue;
    }

    StatusOr<bool> result = TryRemoveWhileLoop(while_op);
    TF_RETURN_IF_ERROR(result.status());
    if (result.ValueOrDie()) {
      changed = true;
      // Don't try to remove dead while params after successfully removing the
      // while loop -- that would result in use-after-free nastiness.
      continue;
    }

    result = TryRemoveDeadWhileParams(while_op);
    TF_RETURN_IF_ERROR(result.status());
    changed |= result.ValueOrDie();
  }

  XLA_VLOG_LINES(3,
                 "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla