1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" 17 #include "absl/container/flat_hash_map.h" 18 #include "absl/container/flat_hash_set.h" 19 #include "absl/strings/str_cat.h" 20 #include "absl/strings/str_join.h" 21 #include "absl/types/optional.h" 22 #include "tensorflow/compiler/xla/primitive_util.h" 23 #include "tensorflow/compiler/xla/service/call_inliner.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 26 #include "tensorflow/compiler/xla/service/hlo_query.h" 27 #include "tensorflow/compiler/xla/service/pattern_matcher.h" 28 #include "tensorflow/compiler/xla/service/while_loop_analysis.h" 29 30 namespace xla { 31 32 namespace m = match; 33 using absl::optional; 34 using hlo_query::ContainsInstrWithOpcode; 35 36 // Tries to remove elements in a while loop's tuple that aren't used within the 37 // loop. 38 // 39 // Specifically, if a loop is tuple-shaped, and there exists some element of 40 // that tuple that is not used by the loop condition and is not used by the loop 41 // body except to pass it to the next iteration of the loop, then we can remove 42 // that element from the loop's tuples. 43 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { 44 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); 45 46 // Don't try this transformation if the while loop isn't removable, since if 47 // it succeeds ultimately we're going to have to replace the old while loop 48 // with a new one. 49 if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { 50 VLOG(2) << "Can't remove dead parameters from non-removable while op."; 51 return false; 52 } 53 54 HloModule* module = while_op->GetModule(); 55 HloComputation* computation = while_op->parent(); 56 HloInstruction* while_init = while_op->mutable_operand(0); 57 HloComputation* while_cond = while_op->while_condition(); 58 HloComputation* while_body = while_op->while_body(); 59 HloInstruction* while_body_root = while_body->root_instruction(); 60 61 if (!while_init->shape().IsTuple()) { 62 VLOG(2) << "While op's carried value isn't tuple shaped."; 63 return false; 64 } 65 66 if (while_body_root->opcode() != HloOpcode::kTuple) { 67 VLOG(2) << "While body's root is not a tuple(...) instruction."; 68 return false; 69 } 70 71 auto print_no_metadata = HloPrintOptions().set_print_metadata(false); 72 73 // Bail if param0 of while_cond or while_body has users which aren't of type 74 // get-tuple-element. 75 for (const HloInstruction* instr : {while_body->parameter_instruction(0), 76 while_cond->parameter_instruction(0)}) { 77 for (const HloInstruction* user : instr->users()) { 78 if (user->opcode() != HloOpcode::kGetTupleElement) { 79 VLOG(2) << "Cowardly refusing to analyze while loop with " 80 << instr->ToString(print_no_metadata) 81 << " used by non-GTE instruction " 82 << user->ToString(print_no_metadata) << " in computation " 83 << instr->parent()->name(); 84 return false; 85 } 86 } 87 } 88 89 const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); 90 if (tuple_size == 0) { 91 VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " 92 "empty."; 93 return false; 94 } 95 96 absl::flat_hash_set<int64> used_tuple_indices; 97 for (HloComputation* comp : {while_body, while_cond}) { 98 // The HLO verifier ensures that while_input's shape matches while_init's 99 // shape, which we verified above is a tuple. 100 HloInstruction* while_input = comp->parameter_instruction(0); 101 102 for (const HloInstruction* user : while_input->users()) { 103 // This user doesn't count if it's only used by the while body's root, and 104 // the root places the tuple element into the same index of the tuple as 105 // it came from. That just amounts to us carrying the variable through 106 // the loop. 107 // 108 // Careful: HloInstruction::operand_index returns the first index the 109 // operand appears in, but it may appear more than once! 110 if (user->user_count() == 1 && user->users().front() == while_body_root && 111 while_body_root->operand_index(user) == user->tuple_index() && 112 absl::c_count(while_body_root->operands(), user) == 1) { 113 continue; 114 } 115 116 used_tuple_indices.insert(user->tuple_index()); 117 if (used_tuple_indices.size() == tuple_size) { 118 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) 119 << " uses all of its inputs; no simplification possible."; 120 return false; 121 } 122 } 123 } 124 125 // If a tuple element is not passed unmodified from the while body's param0 126 // through to the while body's root, count that element as "used", since 127 // removing that element would be observable. 128 for (int64 i = 0; i < while_body_root->operand_count(); ++i) { 129 if (used_tuple_indices.contains(i)) { 130 continue; 131 } 132 133 auto* operand = while_body_root->operand(i); 134 if (operand->opcode() != HloOpcode::kGetTupleElement || 135 operand->operand(0) != while_body->parameter_instruction(0) || 136 operand->tuple_index() != i) { 137 VLOG(2) << "Tuple index " << i 138 << " is not passed through loop body unmodified."; 139 used_tuple_indices.insert(i); 140 141 if (used_tuple_indices.size() == tuple_size) { 142 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) 143 << " uses all of its inputs; no simplification possible."; 144 return false; 145 } 146 } 147 } 148 149 // If we got here, used_tuple_indices.size() < tuple_size, meaning some 150 // elements of the loop's tuple aren't used by while_body or while_cond. 151 CHECK_LT(used_tuple_indices.size(), tuple_size); 152 153 VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() 154 << " elements from tuple of " 155 << while_op->ToString(print_no_metadata); 156 157 // Build up maps from the old/new to the new/old tuple indices. 158 std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(), 159 used_tuple_indices.end()); 160 absl::c_sort(new_to_old_tuple_idx); 161 162 absl::flat_hash_map<int64, int64> old_to_new_tuple_idx; 163 for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { 164 int64 old_idx = new_to_old_tuple_idx[new_idx]; 165 old_to_new_tuple_idx[old_idx] = new_idx; 166 VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx; 167 } 168 169 // Compute the shape of the while op after we remove the dead indices. 170 std::vector<Shape> new_while_tuple_elem_shapes; 171 new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size()); 172 for (int64 old_idx : new_to_old_tuple_idx) { 173 new_while_tuple_elem_shapes.push_back( 174 while_init->shape().tuple_shapes(old_idx)); 175 } 176 Shape new_while_shape = 177 ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes); 178 179 // Returns a map from elements in the computation to new instructions which 180 // replace the old instructions after we remove unused elements from the while 181 // tuple. 182 auto make_while_computation_replacements = [&](const HloComputation* comp) { 183 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 184 replacements; 185 186 auto* param = comp->parameter_instruction(0); 187 replacements.emplace(param, HloInstruction::CreateParameter( 188 0, new_while_shape, param->name())); 189 190 // Materialize param's users, since we're about to add new ones below. 191 std::vector<HloInstruction*> materialized_users(param->users().begin(), 192 param->users().end()); 193 for (const auto* user : materialized_users) { 194 // The while body root is handled separately. 195 if (user == while_body_root) { 196 continue; 197 } 198 CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) 199 << user->ToString(print_no_metadata); 200 201 int64 old_idx = user->tuple_index(); 202 auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); 203 if (new_idx_iter != old_to_new_tuple_idx.end()) { 204 // This is a GTE of an index that survives. Replace it. 205 replacements.emplace( 206 user, HloInstruction::CreateGetTupleElement(user->shape(), param, 207 new_idx_iter->second)); 208 } else { 209 // This is a GTE of an index that we've removed. Remove it from the 210 // cloned computation. 211 CHECK(user->user_count() == 0 || 212 user->user_count() == 1 && 213 user->users().front() == while_body_root) 214 << "Instruction " << user->ToString(print_no_metadata) 215 << " should be unused (except by root of while body), but has " 216 "users: {" 217 << absl::StrJoin(user->users(), ", ", 218 [&](string* out, const HloInstruction* instr) { 219 absl::StrAppend( 220 out, instr->ToString(print_no_metadata)); 221 }) 222 << "}"; 223 224 replacements.emplace(user, nullptr); 225 } 226 } 227 return replacements; 228 }; 229 230 // Create the new while condition, body, and init value. 231 std::unique_ptr<HloComputation> new_while_cond = 232 while_cond->CloneWithReplacements( 233 make_while_computation_replacements(while_cond)); 234 235 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> 236 while_body_replacements = make_while_computation_replacements(while_body); 237 std::vector<HloInstruction*> new_while_body_root_elems; 238 new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); 239 for (int64 old_idx : new_to_old_tuple_idx) { 240 new_while_body_root_elems.push_back( 241 while_body_root->mutable_operand(old_idx)); 242 } 243 while_body_replacements.emplace( 244 while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); 245 std::unique_ptr<HloComputation> new_while_body = 246 while_body->CloneWithReplacements(std::move(while_body_replacements)); 247 248 // Add a new while_init instruction that repackages the old while_init 249 // instruction's elements. We rely on the AlgebraicSimplifier and DCE to 250 // clean this up in the common case where while_init is a tuple op. (It's 251 // definitely tuple-shaped, but it's not necessarily a tuple op.) 252 std::vector<HloInstruction*> new_while_init_elems; 253 new_while_init_elems.reserve(new_to_old_tuple_idx.size()); 254 for (int64 old_idx : new_to_old_tuple_idx) { 255 new_while_init_elems.push_back( 256 computation->AddInstruction(HloInstruction::CreateGetTupleElement( 257 while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); 258 } 259 auto* new_while_init = computation->AddInstruction( 260 HloInstruction::CreateTuple(new_while_init_elems)); 261 262 // Create the new while op. 263 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile( 264 new_while_shape, 265 module->AddEmbeddedComputation(std::move(new_while_cond)), 266 module->AddEmbeddedComputation(std::move(new_while_body)), 267 new_while_init)); 268 269 // Create a tuple op that recreates the output of the old while op. That is, 270 // we transform to 271 // 272 // new_while_init while_init 273 // | | 274 // V | 275 // new_while | 276 // | | 277 // -------| |---- 278 // V V 279 // new_tuple 280 // | 281 // V 282 // (orig. users of while op) 283 // 284 // The tuple simplifier will then simplify this if possible, removing 285 // new_tuple and while_init. 286 std::vector<HloInstruction*> new_tuple_elems; 287 for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { 288 auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); 289 if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { 290 int64 gte_idx = new_tuple_idx_it->second; 291 new_tuple_elems.push_back( 292 computation->AddInstruction(HloInstruction::CreateGetTupleElement( 293 new_while_op->shape().tuple_shapes(gte_idx), new_while_op, 294 gte_idx))); 295 } else { 296 new_tuple_elems.push_back( 297 computation->AddInstruction(HloInstruction::CreateGetTupleElement( 298 while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); 299 } 300 } 301 HloInstruction* new_tuple = 302 computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); 303 TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple)); 304 305 return true; 306 } 307 308 // Removes each loop parameter (i.e. member of the while loop tuple) that is a 309 // constant and is the same in the while loop body and the while loop init. 310 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) { 311 HloModule* module = while_op->GetModule(); 312 HloComputation* computation = while_op->parent(); 313 auto* while_init = while_op->mutable_operand(0); 314 auto* while_body = while_op->while_body(); 315 auto* while_cond = while_op->while_condition(); 316 auto* while_body_root = while_body->root_instruction(); 317 if (while_init->opcode() != HloOpcode::kTuple || 318 while_body_root->opcode() != HloOpcode::kTuple) { 319 return false; 320 } 321 322 TF_RET_CHECK(while_cond->num_parameters() == 1); 323 TF_RET_CHECK(while_body->num_parameters() == 1); 324 TF_RET_CHECK( 325 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); 326 327 absl::flat_hash_set<int64> constant_tuple_indices; 328 const auto& while_shape = while_init->shape(); 329 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 330 auto* init_elem = while_init->operand(i); 331 auto* body_elem = while_body_root->operand(i); 332 if (init_elem->opcode() == HloOpcode::kConstant && 333 body_elem->opcode() == HloOpcode::kConstant && 334 init_elem->literal() == body_elem->literal()) { 335 constant_tuple_indices.insert(i); 336 } 337 } 338 339 if (constant_tuple_indices.empty()) { 340 return false; 341 } 342 343 // OK, we found some constant elements of the while parameter! Eliminate 344 // them. 345 std::vector<Shape> new_while_shape_elems; 346 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 347 if (!constant_tuple_indices.count(i)) { 348 new_while_shape_elems.push_back(while_shape.tuple_shapes(i)); 349 } 350 } 351 Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems); 352 353 // `new_instrs` holds instructions created outside of a computation for 354 // cloning. Elements added here just need to live until the end of the 355 // relevant CloneWithReplacement call. 356 std::vector<std::unique_ptr<HloInstruction>> new_instrs; 357 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) { 358 new_instrs.push_back(std::move(instr)); 359 return new_instrs.back().get(); 360 }; 361 362 // Returns a new tuple without the elements of constant_tuple_indices. 363 auto remove_constant_elems = [&](HloInstruction* instr) { 364 CHECK(ShapeUtil::Compatible(instr->shape(), while_shape)); 365 366 std::vector<HloInstruction*> tuple_elems; 367 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 368 if (!constant_tuple_indices.count(i)) { 369 tuple_elems.push_back( 370 add_new_instr(HloInstruction::CreateGetTupleElement( 371 while_shape.tuple_shapes(i), instr, i))); 372 } 373 } 374 return HloInstruction::CreateTuple(tuple_elems); 375 }; 376 377 auto add_constant_elems = [&](HloInstruction* instr) { 378 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); 379 380 std::vector<HloInstruction*> tuple_elems; 381 int64 j = 0; 382 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 383 if (constant_tuple_indices.count(i)) { 384 tuple_elems.push_back(while_init->mutable_operand(i)); 385 } else { 386 tuple_elems.push_back( 387 add_new_instr(HloInstruction::CreateGetTupleElement( 388 while_shape.tuple_shapes(i), instr, j))); 389 ++j; 390 } 391 } 392 return HloInstruction::CreateTuple(tuple_elems); 393 }; 394 395 // Special case: constant_tuple_indices covers the whole while parameter, so 396 // the new while shape is the empty tuple. In this case, the value of the 397 // while loop is simply equal to the value of `init`. 398 // 399 // It's unfortunate to special-case this, but it's simpler than the 400 // alternative. The problem is that if our while parameter has no 401 // non-constant elems, the tuple returned by `add_constant_elems` won't depend 402 // on instr (the loop body/cond parameter), and therefore 403 // CloneWithReplacementPairs will *leave the parameter out entirely*, creating 404 // invalid HLO. 405 if (ShapeUtil::IsEmptyTuple(new_while_shape)) { 406 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); 407 return true; 408 } 409 410 std::unique_ptr<HloComputation> new_while_cond = 411 while_cond->CloneWithReplacementPairs({ 412 while_cond->parameter_instruction(0), 413 add_constant_elems(add_new_instr(HloInstruction::CreateParameter( 414 0, new_while_shape, 415 while_cond->parameter_instruction(0)->name()))), 416 }); 417 418 std::unique_ptr<HloComputation> new_while_body = 419 while_body->CloneWithReplacementPairs( 420 { 421 while_body->parameter_instruction(0), 422 add_constant_elems(add_new_instr(HloInstruction::CreateParameter( 423 0, new_while_shape, 424 while_cond->parameter_instruction(0)->name()))), 425 }, 426 { 427 while_body->root_instruction(), 428 remove_constant_elems( 429 add_new_instr(while_body->root_instruction()->Clone())), 430 }); 431 432 // Create the final while loop, and add any new instructions created to 433 // `computation`. 434 new_instrs.clear(); 435 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( 436 while_op, 437 add_constant_elems( 438 computation->AddInstruction(HloInstruction::CreateWhile( 439 new_while_shape, 440 module->AddEmbeddedComputation(std::move(new_while_cond)), 441 module->AddEmbeddedComputation(std::move(new_while_body)), 442 add_new_instr(remove_constant_elems(while_init))))))); 443 for (auto& instr : new_instrs) { 444 computation->AddInstruction(std::move(instr)); 445 } 446 return true; 447 } 448 449 // Tries to remove a while loop from the graph. 450 // 451 // - Loops with trip count of 0 can be replaced by the loop's "init" value. 452 // - Loops with trip count of 1 can be replaced by the loop's body, with the 453 // loop itself removed. 454 // 455 // Returns true if it made a change to the graph. 456 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) { 457 // Cowardly refuse to remove loops that are not removable. In practice, 458 // this means that we can't remove loops that contain side-effecting 459 // instructions or have control predecessors/successors. 460 // 461 // This is not a fundamental limitation. The control operands can be moved 462 // onto the new HLOs after simplification, and any side-effecting ops inside 463 // the loop aren't removed, just cloned and added back to the loop. But 464 // moving an op out of the loop also removes implicit control dependencies 465 // between the op and the ops outside the loop, so we'd have to add those back 466 // for things like infeed/outfeed. It gets complicated. So for now we just 467 // avoid it. 468 if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { 469 VLOG(2) << "Not attempting to remove while loop it is not removable: " 470 << while_op->ToShortString(); 471 return false; 472 } 473 474 // Remove while loops with static trip count of 0. 475 optional<int64> trip_count = 476 ComputeWhileLoopTripCount(while_op, 477 /*max_value_returned=*/1); 478 if (trip_count && *trip_count == 0) { 479 // The loop never executes, so the value of the loop is the value of its 480 // "init" operand. 481 auto computation = while_op->parent(); 482 483 // Remove while_op (i.e., call ReplaceInstruction rather than 484 // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in 485 // a loop without an intervening DCE, we don't try to re-remove the loop. 486 TF_RETURN_IF_ERROR(computation->ReplaceInstruction( 487 while_op, while_op->mutable_operand(0))); 488 return true; 489 } 490 491 // Transform while loops with static trip count of 1 into a call op, then 492 // inline the call. 493 if (trip_count && *trip_count == 1) { 494 auto computation = while_op->parent(); 495 auto call_op = computation->AddInstruction(HloInstruction::CreateCall( 496 while_op->shape(), while_op->operands(), while_op->while_body())); 497 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); 498 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, 499 CallInliner::Inline(call_op)); 500 (void)inlined_instructions_map; 501 return true; 502 } 503 return false; 504 } 505 506 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) { 507 auto while_init = while_op->operand(0); 508 if (while_init->opcode() != HloOpcode::kTuple) { 509 return false; 510 } 511 512 auto while_body = while_op->while_body(); 513 auto while_body_root = while_body->root_instruction(); 514 if (while_body_root->opcode() != HloOpcode::kTuple) { 515 return false; 516 } 517 518 auto while_body_param = while_body->parameter_instruction(0); 519 const HloInstruction::InstructionVector& root_operands = 520 while_body_root->operands(); 521 522 // Find the loop invariant tuple elements with scalar constant init value and 523 // build a map from the tuple element index to the constant value. Limit this 524 // to scalar constant values because propagating array constants can regress 525 // performance by forcing us to copy constants. 526 absl::flat_hash_map<int, const HloInstruction*> index_to_constant; 527 for (int i = 0; i < root_operands.size(); i++) { 528 const HloInstruction* init_tuple_elem = nullptr; 529 if (Match(root_operands[i], 530 m::GetTupleElement(m::Op().Is(while_body_param), i) 531 .WithShape(m::Shape().IsScalar())) && 532 Match(while_init->operand(i), m::Constant(&init_tuple_elem))) { 533 VLOG(3) << "Found loop invariant tuple element " << i << " " 534 << init_tuple_elem->ToString(); 535 index_to_constant[i] = init_tuple_elem; 536 } 537 } 538 539 if (index_to_constant.empty()) { 540 return false; 541 } 542 543 // Replace the use of each constant tuple element in the loop_condition and 544 // loop_body with the corresponding constant value. 545 auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> { 546 HloInstruction* param = computation->parameter_instruction(0); 547 bool changed = false; 548 for (auto instr : param->users()) { 549 // Since only a while-loop with a tuple result reaches here, we can safely 550 // assume that `param` is a tuple and the first operand of the 551 // GetTupleElement instruction is a use of `param`. 552 if (instr->opcode() == HloOpcode::kGetTupleElement) { 553 VLOG(3) << "tuple index " << instr->tuple_index() << " " 554 << instr->ToString(); 555 auto iter = index_to_constant.find(instr->tuple_index()); 556 if (iter != index_to_constant.end()) { 557 const HloInstruction* hlo_constant = (*iter).second; 558 VLOG(3) << "Replace use of " << instr->ToString() << " with " 559 << hlo_constant->ToString(); 560 TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( 561 computation->AddInstruction(hlo_constant->Clone()))); 562 changed = true; 563 } 564 } 565 } 566 return changed; 567 }; 568 569 TF_ASSIGN_OR_RETURN(bool changed_cond, 570 propagate_constant(while_op->while_condition())); 571 TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); 572 573 return changed_cond || changed_body; 574 } 575 576 // Converts a flat list of instructions into a tuple of the desired shape. For 577 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns 578 // a tuple of value ((A, B), C). 579 // 580 // desired_shape must be a tuple. (This precondition allows us to return a 581 // unique_ptr rather than a raw ptr.) 582 static std::unique_ptr<HloInstruction> UnflattenTupleInstr( 583 absl::Span<HloInstruction*> instrs, const Shape& desired_shape, 584 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) { 585 CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape); 586 587 // For each child shape in `desired_shape`, slice out the correct number of 588 // `instrs` and call UnflattenTupleInstr recursively. At each step we remove 589 // elements from `instrs` so that it only contains instructions we have not 590 // yet processed. 591 std::vector<HloInstruction*> elems; 592 for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { 593 const Shape& subshape = desired_shape.tuple_shapes(i); 594 if (!subshape.IsTuple()) { 595 elems.push_back(instrs[0]); 596 instrs.remove_prefix(1); 597 continue; 598 } 599 600 // Count the number of leaf nodes underneath desired_shape[i]. 601 int64 num_leaves = 0; 602 ShapeUtil::ForEachSubshape( 603 subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { 604 if (!s.IsTuple()) { 605 ++num_leaves; 606 } 607 }); 608 609 std::unique_ptr<HloInstruction> subinstr = 610 UnflattenTupleInstr(instrs.subspan(0, num_leaves), 611 desired_shape.tuple_shapes(i), new_instrs); 612 elems.push_back(subinstr.get()); 613 new_instrs->push_back(std::move(subinstr)); 614 instrs.remove_prefix(num_leaves); 615 } 616 return HloInstruction::CreateTuple(elems); 617 } 618 619 // Builds a vector whose elements are the values in the flattened tuple for 620 // `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the 621 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C). 622 static std::vector<HloInstruction*> GetFlatTupleElems( 623 HloInstruction* instr, 624 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) { 625 const auto& shape = instr->shape(); 626 if (!shape.IsTuple()) { 627 return {instr}; 628 } 629 std::vector<HloInstruction*> elems; 630 for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { 631 const Shape& subshape = shape.tuple_shapes(i); 632 new_instrs->push_back( 633 HloInstruction::CreateGetTupleElement(subshape, instr, i)); 634 auto* gte = new_instrs->back().get(); 635 auto flattened_subshape = GetFlatTupleElems(gte, new_instrs); 636 elems.insert(elems.end(), flattened_subshape.begin(), 637 flattened_subshape.end()); 638 } 639 return elems; 640 } 641 642 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) { 643 HloModule* module = while_op->GetModule(); 644 HloComputation* computation = while_op->parent(); 645 auto* while_init = while_op->mutable_operand(0); 646 auto* while_body = while_op->while_body(); 647 auto* while_cond = while_op->while_condition(); 648 auto* while_body_root = while_body->root_instruction(); 649 if (while_init->opcode() != HloOpcode::kTuple || 650 while_body_root->opcode() != HloOpcode::kTuple) { 651 return false; 652 } 653 654 TF_RET_CHECK(while_cond->num_parameters() == 1); 655 TF_RET_CHECK(while_body->num_parameters() == 1); 656 TF_RET_CHECK( 657 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); 658 Shape while_shape = while_init->shape(); 659 if (!ShapeUtil::IsNestedTuple(while_shape)) { 660 return false; 661 } 662 663 std::vector<Shape> flattened_shape_elems; 664 ShapeUtil::ForEachSubshape(while_shape, 665 [&](const Shape& s, const ShapeIndex& /*index*/) { 666 if (!s.IsTuple()) { 667 flattened_shape_elems.push_back(s); 668 } 669 }); 670 Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems); 671 672 // `new_instrs` holds instructions created outside of a computation for 673 // cloning. Elements added here just need to live until the end of the 674 // relevant CloneWithReplacement call. 675 std::vector<std::unique_ptr<HloInstruction>> new_instrs; 676 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) { 677 new_instrs.push_back(std::move(instr)); 678 return new_instrs.back().get(); 679 }; 680 681 auto nested = [&](HloInstruction* instr) { 682 std::vector<HloInstruction*> gtes; 683 const Shape& flat_shape = instr->shape(); 684 for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) { 685 gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement( 686 flat_shape.tuple_shapes(i), instr, i))); 687 } 688 auto nested_instr = 689 UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs); 690 CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape)) 691 << ShapeUtil::HumanString(nested_instr->shape()) << " vs " 692 << ShapeUtil::HumanString(while_shape); 693 return nested_instr; 694 }; 695 696 auto flattened = [&](HloInstruction* instr) { 697 return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs)); 698 }; 699 700 // Create a new while-condition computation, where parameter 0 has flat shape 701 // but all uses of it go through the nested shape. 702 std::unique_ptr<HloComputation> new_while_cond = 703 while_cond->CloneWithReplacementPairs({ 704 while_cond->parameter_instruction(0), 705 nested(add_new_instr(HloInstruction::CreateParameter( 706 0, flattened_shape, 707 while_cond->parameter_instruction(0)->name()))), 708 }); 709 710 // Create a new while-body computation, where parameter 0 has a flat shape and 711 // all uses of it go through the nested shape, and where the root has a flat 712 // shape constructed from the old nested root. 713 std::unique_ptr<HloComputation> new_while_body = 714 while_body->CloneWithReplacementPairs( 715 { 716 while_body->parameter_instruction(0), 717 nested(add_new_instr(HloInstruction::CreateParameter( 718 0, flattened_shape, 719 while_body->parameter_instruction(0)->name()))), 720 }, 721 { 722 while_body->root_instruction(), 723 flattened(add_new_instr(while_body->root_instruction()->Clone())), 724 }); 725 726 // Create the final while loop, and add any new instructions created to 727 // `computation`. 728 new_instrs.clear(); 729 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( 730 while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile( 731 flattened_shape, 732 module->AddEmbeddedComputation(std::move(new_while_cond)), 733 module->AddEmbeddedComputation(std::move(new_while_body)), 734 computation->AddInstruction(flattened(while_init))))))); 735 for (auto& instr : new_instrs) { 736 computation->AddInstruction(std::move(instr)); 737 } 738 return true; 739 } 740 741 // Tries to merge loop induction variables of a given type. 742 // 743 // In this pass we're only concerned with elements of the loop's tuple that 744 // are effective-scalars of type `elem_ty`. Some terminology: 745 // 746 // - The trip counter is the first element of the loop's tuple that starts at 747 // 0 and does x++ on each iteration. 748 // 749 // - An induction variable is an element of the loop's tuple that is not the 750 // trip counter and does `x += <constant>` on each iteration of the loop. 751 // Negative constants are OK. 752 // 753 // This pass adds a trip counter if one isn't already present, then replaces 754 // each induction variable with 755 // 756 // <initial_value> + <trip_count> * <constant>. 757 // 758 // This reduces the number of scalar operations in the loop, which is important 759 // e.g. on GPUs, where each scalar operation is nontrivially expensive because 760 // it's a separate kernel launch. 761 // 762 // Returns the new loop if a change was made, or null if no change was made. 763 // Note that the new loop is not a valid replacement for the old loop; it may 764 // need to be wrapped in a tuple that changes its shape. We return the loop 765 // itself so that you can call TryMergeInductionVariables in a loop, once for 766 // each integral type elem_ty. 767 static StatusOr<HloInstruction*> TryMergeInductionVariables( 768 HloInstruction* while_op, PrimitiveType elem_ty) { 769 CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty); 770 HloModule* module = while_op->GetModule(); 771 HloComputation* computation = while_op->parent(); 772 auto* while_init = while_op->mutable_operand(0); 773 auto* while_body = while_op->while_body(); 774 auto* while_cond = while_op->while_condition(); 775 auto* while_body_root = while_body->root_instruction(); 776 if (while_init->opcode() != HloOpcode::kTuple || 777 while_body_root->opcode() != HloOpcode::kTuple) { 778 return nullptr; 779 } 780 781 TF_RET_CHECK(while_cond->num_parameters() == 1); 782 TF_RET_CHECK(while_body->num_parameters() == 1); 783 TF_RET_CHECK( 784 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); 785 Shape while_shape = while_init->shape(); 786 787 // The tuple index of the trip counter, if one is present. 788 absl::optional<int64> trip_counter; 789 // Maps the tuple index of each induction variable to its constant increment. 790 absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars; 791 for (int64 i = 0; i < while_body_root->operand_count(); ++i) { 792 HloInstruction* constant; 793 if (!Match(while_body_root->mutable_operand(i), 794 m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i), 795 m::ConstantScalar(&constant)) 796 .WithShape(m::Shape().WithElementType(elem_ty)))) { 797 continue; 798 } 799 if (!trip_counter && constant->literal().IsAll(1) && 800 while_init->operand(i)->IsConstant() && 801 while_init->operand(i)->literal().IsAll(0)) { 802 VLOG(10) << "Found existing trip counter at index " << i; 803 trip_counter = i; 804 } else { 805 VLOG(10) << "Found induction variable at index " << i; 806 induction_vars.emplace(i, Cast<HloConstantInstruction>(constant)); 807 } 808 } 809 810 // There's only something to simplify if we can either: 811 // 812 // - combine one or more induction vars with an existing trip counter, or 813 // - replace two or more induction variables with a new trip counter. 814 // 815 // Put another way, there's only something to simplify if the number of 816 // induction vars plus the number of existing trip counters (0 or 1) is >= 2. 817 if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) { 818 return nullptr; 819 } 820 821 // OK, we're going to do the transformation! Set up some helpers. 822 823 // `new_instrs` holds instructions created outside of a computation for 824 // cloning. Elements added here just need to live until the end of the 825 // relevant CloneWithReplacement call. 826 std::vector<std::unique_ptr<HloInstruction>> new_instrs; 827 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) { 828 new_instrs.push_back(std::move(instr)); 829 return new_instrs.back().get(); 830 }; 831 832 auto add_binary_op = [&](const Shape& shape, HloOpcode opcode, 833 HloInstruction* lhs, HloInstruction* rhs) { 834 // Reshape lhs/rhs to the output shape if necessary. This deals with the 835 // fact that induction variables need only be effective scalars, not true 836 // scalars. 837 if (!ShapeUtil::Compatible(shape, lhs->shape())) { 838 lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs)); 839 } 840 if (!ShapeUtil::Compatible(shape, rhs->shape())) { 841 rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs)); 842 } 843 return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs)); 844 }; 845 846 auto add_gte = [&](HloInstruction* src, int64 idx) { 847 return add_new_instr(HloInstruction::CreateGetTupleElement( 848 src->shape().tuple_shapes(idx), src, idx)); 849 }; 850 851 // Our new while loop will have the same shape as the old while loop, except 852 // we'll add a trip counter to the end if it wasn't originally present. 853 Shape new_while_shape = while_shape; 854 bool added_trip_counter = false; 855 if (!trip_counter) { 856 VLOG(10) << "Adding new trip counter to end of loop's tuple."; 857 trip_counter = new_while_shape.tuple_shapes_size(); 858 *new_while_shape.add_tuple_shapes() = 859 ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{}); 860 added_trip_counter = true; 861 } 862 863 // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with 864 // shape `while_body->shape()` and where the induction variables are "reified" 865 // (i.e. they have value <init> + <counter> * <constant>). 866 auto convert_to_old_form = [&](HloInstruction* instr) { 867 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); 868 std::vector<HloInstruction*> tuple_elems; 869 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 870 const auto& elem_shape = while_shape.tuple_shapes(i); 871 if (!induction_vars.count(i)) { 872 tuple_elems.push_back(add_gte(instr, i)); 873 continue; 874 } 875 tuple_elems.push_back(add_binary_op( 876 elem_shape, HloOpcode::kAdd, add_gte(instr, i), 877 add_binary_op(elem_shape, HloOpcode::kMultiply, 878 add_gte(instr, *trip_counter), 879 add_new_instr(induction_vars.at(i)->Clone())))); 880 } 881 return HloInstruction::CreateTuple(tuple_elems); 882 }; 883 884 // Converts `root` into a tuple of the "new" form -- that is, to a tuple with 885 // shape `new_while_shape` and where the induction variables (but not trip 886 // counters) are replaced with their unchanging <loop_body_param> values. 887 auto convert_to_new_form = [&](HloInstruction* old_root, 888 HloParameterInstruction* loop_body_param) { 889 CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape)); 890 std::vector<HloInstruction*> tuple_elems; 891 892 // In the new form, induction variables come from `init`, everything else 893 // (including the trip counter if it's not one we created ourselves) comes 894 // from the `root` tuple unmodified. 895 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 896 tuple_elems.push_back( 897 add_gte((induction_vars.count(i) ? loop_body_param : old_root), i)); 898 } 899 // If we created a trip counter ourselves, add 1 to it in the next 900 // iteration. 901 if (added_trip_counter) { 902 tuple_elems.push_back(add_binary_op( 903 new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd, 904 add_gte(loop_body_param, *trip_counter), 905 add_new_instr( 906 HloInstruction::CreateConstant(LiteralUtil::One(elem_ty))))); 907 } 908 909 return HloInstruction::CreateTuple(tuple_elems); 910 }; 911 912 // Creates a new init tuple, which is the same as the old init tuple except if 913 // we added a trip counter, it's set to 0. 914 auto get_new_while_init = [&](HloInstruction* init) { 915 CHECK(ShapeUtil::Compatible(init->shape(), while_shape)); 916 if (!added_trip_counter) { 917 return init; 918 } 919 std::vector<HloInstruction*> tuple_elems; 920 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { 921 tuple_elems.push_back(add_gte(init, i)); 922 } 923 tuple_elems.push_back(add_new_instr( 924 HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty)))); 925 return add_new_instr(HloInstruction::CreateTuple(tuple_elems)); 926 }; 927 928 std::unique_ptr<HloComputation> new_while_cond = 929 while_cond->CloneWithReplacementPairs({ 930 while_cond->parameter_instruction(0), 931 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( 932 0, new_while_shape, 933 while_cond->parameter_instruction(0)->name()))), 934 }); 935 936 // Creating the new while body proceeds in two steps. First we convert the 937 // users of the parameter to the old form. Then as a second 938 // CloneWithReplacement operation we convert the root to the new form. We 939 // have to do this in two steps because the new root needs to use the new 940 // param0, and during the first clone operation, only the *old-form* param0 is 941 // accessible. 942 // 943 // We have to add temp_new_while_body to the module because cloning a 944 // computation touches the module (to get its NameUniquer). 945 HloComputation* temp_new_while_body = 946 module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({ 947 while_body->parameter_instruction(0), 948 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( 949 0, new_while_shape, 950 while_body->parameter_instruction(0)->name()))), 951 })); 952 std::unique_ptr<HloComputation> new_while_body = 953 temp_new_while_body->CloneWithReplacementPairs({ 954 temp_new_while_body->root_instruction(), 955 convert_to_new_form( 956 add_new_instr(temp_new_while_body->root_instruction()->Clone()), 957 Cast<HloParameterInstruction>( 958 temp_new_while_body->parameter_instruction(0))), 959 }); 960 TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); 961 962 // Create the final while loop, and add any new instructions created to 963 // `computation`. 964 new_instrs.clear(); 965 auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile( 966 new_while_shape, 967 module->AddEmbeddedComputation(std::move(new_while_cond)), 968 module->AddEmbeddedComputation(std::move(new_while_body)), 969 get_new_while_init(while_init))); 970 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( 971 while_op, convert_to_old_form(new_while))); 972 for (auto& instr : new_instrs) { 973 computation->AddInstruction(std::move(instr)); 974 } 975 return new_while; 976 } 977 978 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) { 979 XLA_VLOG_LINES(3, 980 "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); 981 bool changed = false; 982 983 // Gather all the while ops in our module. We do this ahead of time so we 984 // don't have to worry about mutating the lists of computations or 985 // instructions while we iterate. 986 std::vector<HloInstruction*> while_ops; 987 for (auto* comp : module->computations()) { 988 for (auto* instr : comp->instructions()) { 989 if (instr->opcode() == HloOpcode::kWhile) { 990 while_ops.push_back(instr); 991 } 992 } 993 } 994 995 for (HloInstruction* while_op : while_ops) { 996 // We can't remove while loops that contain send/recv nodes, because we rely 997 // on the particular loop structure around the node matching on the send and 998 // recv sides. Other while simplifications require us to remove the loop 999 // and replace it with a new one, so we can't do that either. 1000 if (ContainsInstrWithOpcode(while_op->while_body(), 1001 {HloOpcode::kSend, HloOpcode::kSendDone, 1002 HloOpcode::kRecv, HloOpcode::kRecvDone}) || 1003 ContainsInstrWithOpcode(while_op->while_condition(), 1004 {HloOpcode::kSend, HloOpcode::kSendDone, 1005 HloOpcode::kRecv, HloOpcode::kRecvDone})) { 1006 VLOG(2) << "Not attempting to simplify while loop because it contains a " 1007 "send/recv node: " 1008 << while_op->ToShortString(); 1009 continue; 1010 } 1011 1012 TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); 1013 changed |= result; 1014 1015 TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); 1016 changed |= result; 1017 if (result) { 1018 // Don't continue simplifying after successfully removing the while loop 1019 // -- that would result in use-after-free nastiness. 1020 continue; 1021 } 1022 1023 // TODO(b/119281462): Cowardly refuse to perform any of the following 1024 // optimizations in the presence of kDomain instructions. It seems that 1025 // modifying a while loop's tuple doesn't work when kDomain is present. 1026 if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) || 1027 ContainsInstrWithOpcode(while_op->while_condition(), 1028 {HloOpcode::kDomain})) { 1029 continue; 1030 } 1031 1032 // Each of the optimizations below modifies the while loop itself if it's 1033 // successful, meaning that `while_op` is no longer valid after one of these 1034 // transformations returns true. 1035 1036 TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); 1037 changed |= result; 1038 if (result) { 1039 continue; 1040 } 1041 1042 TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); 1043 changed |= result; 1044 if (result) { 1045 continue; 1046 } 1047 1048 TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); 1049 changed |= result; 1050 if (result) { 1051 continue; 1052 } 1053 1054 bool merged_induction_vars = false; 1055 // Notably missing from this list are S16 and U16. These don't currently 1056 // work because S/U16 literals are not implemented. 1057 for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { 1058 TF_ASSIGN_OR_RETURN(auto* new_while_op, 1059 TryMergeInductionVariables(while_op, elem_ty)); 1060 if (new_while_op) { 1061 while_op = new_while_op; 1062 changed = true; 1063 merged_induction_vars = true; 1064 } 1065 } 1066 if (merged_induction_vars) { 1067 continue; 1068 } 1069 } 1070 1071 XLA_VLOG_LINES(3, 1072 "WhileLoopSimplifier::Run(), after:\n" + module->ToString()); 1073 return changed; 1074 } 1075 1076 } // namespace xla 1077