1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 19 #include "tensorflow/compiler/xla/test.h" 20 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" 21 #include "tensorflow/core/lib/core/status_test_util.h" 22 23 namespace xla { 24 namespace { 25 26 namespace op = xla::testing::opcode_matchers; 27 28 class WhileLoopSimplifierTest : public HloVerifiedTestBase { 29 public: 30 // Makes a computation that contains a loop that runs num_iters times. 31 HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); 32 33 // Makes a computation which has one parameter, of the given shape, and always 34 // returns PRED[]{true}. This is useful as a dummy loop condition. 35 HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, 36 HloModule* module); 37 }; 38 39 HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, 40 HloModule* module) { 41 HloComputation::Builder builder(TestName()); 42 43 auto loop_iter_init = builder.AddInstruction( 44 HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); 45 auto loop_data_init = builder.AddInstruction( 46 HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1, 2}))); 47 auto loop_init = builder.AddInstruction( 48 HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); 49 50 HloComputation* condition; 51 { 52 HloComputation::Builder cond_builder(TestName() + ".condition"); 53 auto loop_var = cond_builder.AddInstruction( 54 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 55 auto loop_induction_var = 56 cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( 57 ShapeUtil::MakeShape(S32, {}), loop_var, 0)); 58 auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( 59 Literal::CreateR0<int32>(42 + num_iters))); 60 cond_builder.AddInstruction(HloInstruction::CreateBinary( 61 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, 62 limit)); 63 condition = module->AddEmbeddedComputation(cond_builder.Build()); 64 } 65 66 HloComputation* body; 67 { 68 HloComputation::Builder body_builder(TestName() + ".body"); 69 auto loop_var = body_builder.AddInstruction( 70 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 71 auto loop_induction_var = 72 body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( 73 ShapeUtil::MakeShape(S32, {}), loop_var, 0)); 74 auto new_loop_induction_var = 75 body_builder.AddInstruction(HloInstruction::CreateBinary( 76 loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, 77 body_builder.AddInstruction( 78 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))))); 79 auto loop_data = 80 body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( 81 loop_data_init->shape(), loop_var, 1)); 82 auto new_loop_data = 83 body_builder.AddInstruction(HloInstruction::CreateBinary( 84 loop_data_init->shape(), HloOpcode::kMultiply, loop_data, 85 loop_data)); 86 body_builder.AddInstruction( 87 HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); 88 body = module->AddEmbeddedComputation(body_builder.Build()); 89 } 90 91 builder.AddInstruction(HloInstruction::CreateWhile( 92 loop_init->shape(), condition, body, loop_init)); 93 94 return module->AddEntryComputation(builder.Build()); 95 } 96 97 HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( 98 const Shape& param_shape, HloModule* module) { 99 HloComputation::Builder builder(TestName() + ".always_true"); 100 builder.AddInstruction( 101 HloInstruction::CreateParameter(0, param_shape, "param")); 102 builder.AddInstruction( 103 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 104 return module->AddEmbeddedComputation(builder.Build()); 105 } 106 107 TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { 108 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); 109 ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 110 EXPECT_THAT(computation->root_instruction(), 111 op::Tuple(op::Constant(), op::Constant())); 112 } 113 114 TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { 115 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); 116 ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 117 EXPECT_THAT(computation->root_instruction(), 118 op::Tuple(op::Add(), op::Multiply())); 119 } 120 121 TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { 122 MakeSimpleLoop(/*num_iters=*/2, &module()); 123 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 124 } 125 126 TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { 127 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); 128 auto* while_op = computation->root_instruction(); 129 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); 130 auto* true_op = while_op->while_body()->AddInstruction( 131 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 132 TF_ASSERT_OK(true_op->AddControlDependencyTo( 133 while_op->while_body()->root_instruction())); 134 ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 135 EXPECT_THAT(computation->root_instruction()->control_predecessors(), 136 ElementsAre(op::Constant())) 137 << computation->ToString(); 138 } 139 140 // Loops that contain send/recv nodes can't be simplified; the loop structure 141 // around send/recv nodes must be preserved. 142 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { 143 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); 144 auto* while_op = computation->root_instruction(); 145 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); 146 auto* while_body = while_op->while_body(); 147 auto* send = while_body->AddInstruction(HloInstruction::CreateSend( 148 while_body->AddInstruction( 149 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))), 150 /*channel_id=*/0)); 151 while_body->AddInstruction(HloInstruction::CreateSendDone(send)); 152 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 153 } 154 155 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { 156 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); 157 auto* while_op = computation->root_instruction(); 158 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); 159 auto* while_body = while_op->while_body(); 160 auto* recv = while_body->AddInstruction( 161 HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), 162 /*channel_id=*/0)); 163 while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); 164 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 165 } 166 167 // The limitation on not being able to simplify loops that contain infeeds (and 168 // other non-removable instructions) isn't fundamental -- it just stems from the 169 // fact that our infrastructure sees simplifying such a loop as tantamount to 170 // removing the non-removable instruction. 171 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { 172 HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); 173 auto* while_op = computation->root_instruction(); 174 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); 175 auto* while_body = while_op->while_body(); 176 while_body->AddInstruction( 177 HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); 178 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 179 } 180 181 // Check that we don't crash when given a loop whose shape is not a tuple. 182 TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { 183 HloComputation::Builder builder(TestName()); 184 auto loop_init = builder.AddInstruction( 185 HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); 186 187 HloComputation* condition; 188 { 189 HloComputation::Builder cond_builder(TestName() + ".condition"); 190 auto param = cond_builder.AddInstruction( 191 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 192 cond_builder.AddInstruction(HloInstruction::CreateBinary( 193 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, 194 cond_builder.AddInstruction( 195 HloInstruction::CreateConstant(Literal::CreateR0<int32>(100))))); 196 condition = module().AddEmbeddedComputation(cond_builder.Build()); 197 } 198 199 HloComputation* body; 200 { 201 HloComputation::Builder body_builder(TestName() + ".body"); 202 auto param = body_builder.AddInstruction( 203 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 204 body_builder.AddInstruction(HloInstruction::CreateBinary( 205 ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, 206 body_builder.AddInstruction( 207 HloInstruction::CreateConstant(Literal::CreateR0<int32>(-1))))); 208 body = module().AddEmbeddedComputation(body_builder.Build()); 209 } 210 211 builder.AddInstruction(HloInstruction::CreateWhile( 212 loop_init->shape(), condition, body, loop_init)); 213 214 module().AddEntryComputation(builder.Build()); 215 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 216 } 217 218 // Construct a loop where we swap the tuple elements in each iteration. 219 // Although the tuple elements aren't used in the loop, we don't eliminate them, 220 // because the swapping side-effect is visible to users of the loop. 221 TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { 222 HloComputation::Builder builder(TestName()); 223 auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ 224 builder.AddInstruction( 225 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 226 builder.AddInstruction( 227 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))), 228 })); 229 230 HloComputation* condition = 231 MakeAlwaysTrueComputation(loop_init->shape(), &module()); 232 HloComputation* body; 233 { 234 HloComputation::Builder body_builder(TestName() + ".body"); 235 auto param = body_builder.AddInstruction( 236 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 237 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 238 body_builder.AddInstruction(HloInstruction::CreateTuple({ 239 body_builder.AddInstruction( 240 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), 241 body_builder.AddInstruction( 242 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), 243 })); 244 body = module().AddEmbeddedComputation(body_builder.Build()); 245 } 246 247 builder.AddInstruction(HloInstruction::CreateWhile( 248 loop_init->shape(), condition, body, loop_init)); 249 250 module().AddEntryComputation(builder.Build()); 251 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 252 } 253 254 // Construct a loop where we assign a constant to tuple element 0 in each 255 // iteration. We can't eliminate tuple element 0, even though we never use its 256 // value. 257 TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { 258 HloComputation::Builder builder(TestName()); 259 auto loop_init = builder.AddInstruction( 260 HloInstruction::CreateTuple({builder.AddInstruction( 261 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)))})); 262 263 HloComputation* condition = 264 MakeAlwaysTrueComputation(loop_init->shape(), &module()); 265 HloComputation* body; 266 { 267 HloComputation::Builder body_builder(TestName() + ".body"); 268 body_builder.AddInstruction( 269 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 270 body_builder.AddInstruction(HloInstruction::CreateTuple({ 271 body_builder.AddInstruction( 272 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))), 273 })); 274 body = module().AddEmbeddedComputation(body_builder.Build()); 275 } 276 277 builder.AddInstruction(HloInstruction::CreateWhile( 278 loop_init->shape(), condition, body, loop_init)); 279 280 module().AddEntryComputation(builder.Build()); 281 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 282 } 283 284 // Nothing to simplify in a while loop whose tuple has 0 elements. 285 TEST_F(WhileLoopSimplifierTest, EmptyTuple) { 286 HloComputation::Builder builder(TestName()); 287 auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); 288 289 HloComputation* condition = 290 MakeAlwaysTrueComputation(loop_init->shape(), &module()); 291 HloComputation* body; 292 { 293 HloComputation::Builder body_builder(TestName() + ".body"); 294 body_builder.AddInstruction( 295 HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); 296 body_builder.AddInstruction(HloInstruction::CreateTuple({})); 297 body = module().AddEmbeddedComputation(body_builder.Build()); 298 } 299 300 builder.AddInstruction(HloInstruction::CreateWhile( 301 loop_init->shape(), condition, body, loop_init)); 302 module().AddEntryComputation(builder.Build()); 303 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 304 } 305 306 // While loop where one tuple element is used twice in the body, and thus can't 307 // be simplified away. 308 TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { 309 HloComputation::Builder builder(TestName()); 310 auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ 311 builder.AddInstruction( 312 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 313 builder.AddInstruction( 314 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))), 315 })); 316 317 HloComputation* condition = 318 MakeAlwaysTrueComputation(loop_init->shape(), &module()); 319 320 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 321 HloComputation* body; 322 { 323 HloComputation::Builder body_builder(TestName() + ".body"); 324 auto* param = body_builder.AddInstruction( 325 HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); 326 auto* gte0 = body_builder.AddInstruction( 327 HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); 328 // get0 is used twice in the loop body's tuple. 329 body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); 330 body = module().AddEmbeddedComputation(body_builder.Build()); 331 } 332 333 builder.AddInstruction(HloInstruction::CreateWhile( 334 loop_init->shape(), condition, body, loop_init)); 335 module().AddEntryComputation(builder.Build()); 336 EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 337 } 338 339 // This while loop has three tuple elements. Element 0 is unused and should be 340 // removed. Element 1 is used by the loop body, and element 2 is used by the 341 // loop condition; these two should stay. 342 TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { 343 HloComputation::Builder builder(TestName()); 344 auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ 345 builder.AddInstruction( 346 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 347 builder.AddInstruction( 348 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 349 builder.AddInstruction( 350 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 351 })); 352 auto loop_shape = loop_init->shape(); 353 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 354 355 HloComputation* condition; 356 { 357 HloComputation::Builder cond_builder(TestName() + ".loop_condition"); 358 auto param = cond_builder.AddInstruction( 359 HloInstruction::CreateParameter(0, loop_shape, "param0")); 360 cond_builder.AddInstruction(HloInstruction::CreateBinary( 361 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, 362 cond_builder.AddInstruction( 363 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))), 364 cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( 365 scalar_s32, param, /*index=*/2)))); 366 condition = module().AddEmbeddedComputation(cond_builder.Build()); 367 } 368 369 HloComputation* body; 370 { 371 HloComputation::Builder body_builder(TestName() + ".body"); 372 auto* param = body_builder.AddInstruction( 373 HloInstruction::CreateParameter(0, loop_shape, "loop_var")); 374 375 auto* tuple0 = body_builder.AddInstruction( 376 HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); 377 auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( 378 scalar_s32, HloOpcode::kAdd, 379 body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( 380 scalar_s32, param, /*index=*/1)), 381 body_builder.AddInstruction( 382 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))))); 383 auto* tuple2 = body_builder.AddInstruction( 384 HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); 385 body_builder.AddInstruction( 386 HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); 387 388 body = module().AddEmbeddedComputation(body_builder.Build()); 389 } 390 391 auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( 392 loop_init->shape(), condition, body, loop_init)); 393 394 module().AddEntryComputation(builder.Build()); 395 EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); 396 397 // We leave most of the checking to HloVerifiedTestBase, which runs the 398 // verifier on module() at the end of this test. 399 HloInstruction* new_while_op = *std::find_if( 400 module().entry_computation()->instructions().begin(), 401 module().entry_computation()->instructions().end(), 402 [&](const HloInstruction* instr) { 403 return instr != while_op && instr->opcode() == HloOpcode::kWhile; 404 }); 405 EXPECT_TRUE( 406 ShapeUtil::Equal(new_while_op->shape(), 407 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) 408 << ShapeUtil::HumanString(new_while_op->shape()); 409 EXPECT_THAT( 410 new_while_op->while_body()->root_instruction(), 411 op::Tuple( 412 op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0), 413 op::Constant()), 414 op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); 415 416 EXPECT_THAT(new_while_op->while_condition()->root_instruction(), 417 op::Eq(op::Constant(), 418 op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); 419 } 420 421 TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { 422 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 423 Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); 424 425 HloComputation* while_body = [&]() { 426 HloComputation::Builder builder(TestName() + ".passthrough"); 427 HloInstruction* param = builder.AddInstruction( 428 HloInstruction::CreateParameter(0, while_shape, "param")); 429 HloComputation* result = module().AddEmbeddedComputation(builder.Build()); 430 431 result->AddInstruction( 432 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 433 return result; 434 }(); 435 436 HloComputation::Builder builder(TestName()); 437 auto* init_value = builder.AddInstruction( 438 HloInstruction::CreateParameter(0, while_shape, "init_value")); 439 builder.AddInstruction(HloInstruction::CreateWhile( 440 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 441 while_body, init_value)); 442 module().AddEntryComputation(builder.Build()); 443 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 444 WhileLoopSimplifier{}.Run(&module())); 445 EXPECT_FALSE(simplified_loop); 446 } 447 448 } // namespace 449 } // namespace xla 450