1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 19 #include "tensorflow/compiler/xla/test.h" 20 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" 21 #include "tensorflow/core/lib/core/status_test_util.h" 22 23 namespace xla { 24 namespace { 25 26 namespace op = xla::testing::opcode_matchers; 27 28 class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { 29 public: 30 // Makes a computation which has one parameter, of the given shape, and always 31 // returns PRED[]{true}. This is useful as a dummy loop condition. 32 HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, 33 HloModule* module); 34 }; 35 36 static void FindOnlyWhileInstruction(HloComputation* computation, 37 HloInstruction** while_instruction) { 38 *while_instruction = nullptr; 39 for (auto* instr : computation->instructions()) { 40 if (instr->opcode() == HloOpcode::kWhile) { 41 ASSERT_EQ(*while_instruction, nullptr); 42 *while_instruction = instr; 43 } 44 } 45 46 ASSERT_NE(*while_instruction, nullptr); 47 } 48 49 HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( 50 const Shape& param_shape, HloModule* module) { 51 HloComputation::Builder builder(TestName() + ".always_true"); 52 builder.AddInstruction( 53 HloInstruction::CreateParameter(0, param_shape, "param")); 54 builder.AddInstruction( 55 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 56 return module->AddEmbeddedComputation(builder.Build()); 57 } 58 59 TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { 60 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 61 Shape while_shape = 62 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); 63 64 HloComputation* while_body = [&]() { 65 HloComputation::Builder builder(TestName() + ".while_body"); 66 HloInstruction* param = builder.AddInstruction( 67 HloInstruction::CreateParameter(0, while_shape, "param")); 68 HloInstruction* gte_0 = builder.AddInstruction( 69 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 70 HloInstruction* gte_1 = builder.AddInstruction( 71 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 72 HloInstruction* add_result = 73 builder.AddInstruction(HloInstruction::CreateBinary( 74 scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); 75 builder.AddInstruction( 76 HloInstruction::CreateTuple({gte_0, gte_1, add_result})); 77 78 return module().AddEmbeddedComputation(builder.Build()); 79 }(); 80 81 HloComputation::Builder builder(TestName()); 82 auto* init_value = builder.AddInstruction( 83 HloInstruction::CreateParameter(0, while_shape, "init_value")); 84 builder.AddInstruction(HloInstruction::CreateWhile( 85 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 86 while_body, init_value)); 87 HloComputation* entry_computation = 88 module().AddEntryComputation(builder.Build()); 89 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 90 WhileLoopInvariantCodeMotion{}.Run(&module())); 91 EXPECT_TRUE(simplified_loop); 92 93 HloInstruction* transformed_while; 94 FindOnlyWhileInstruction(entry_computation, &transformed_while); 95 96 EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); 97 EXPECT_THAT(transformed_while->while_body()->instructions(), 98 Each(Not(op::Add()))); 99 } 100 101 TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { 102 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 103 Shape while_shape = 104 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); 105 106 HloComputation* while_body = [&]() { 107 HloComputation::Builder builder(TestName() + ".while_body"); 108 HloInstruction* param = builder.AddInstruction( 109 HloInstruction::CreateParameter(0, while_shape, "param")); 110 HloInstruction* gte_0 = builder.AddInstruction( 111 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 112 HloInstruction* gte_1 = builder.AddInstruction( 113 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 114 HloInstruction* gte_2_loop_variant = builder.AddInstruction( 115 HloInstruction::CreateGetTupleElement(scalar_s32, param, 2)); 116 117 HloInstruction* add_result = 118 builder.AddInstruction(HloInstruction::CreateBinary( 119 scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); 120 HloInstruction* mul_result = 121 builder.AddInstruction(HloInstruction::CreateBinary( 122 scalar_s32, HloOpcode::kMultiply, add_result, gte_1)); 123 HloInstruction* negate_result = 124 builder.AddInstruction(HloInstruction::CreateUnary( 125 scalar_s32, HloOpcode::kNegate, mul_result)); 126 HloInstruction* constant = builder.AddInstruction( 127 HloInstruction::CreateConstant(Literal::CreateR0<int32>(4))); 128 HloInstruction* sub_result = 129 builder.AddInstruction(HloInstruction::CreateBinary( 130 scalar_s32, HloOpcode::kSubtract, negate_result, constant)); 131 HloInstruction* divide_result = 132 builder.AddInstruction(HloInstruction::CreateBinary( 133 scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant)); 134 builder.AddInstruction( 135 HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); 136 137 return module().AddEmbeddedComputation(builder.Build()); 138 }(); 139 140 HloComputation::Builder builder(TestName()); 141 auto* init_value = builder.AddInstruction( 142 HloInstruction::CreateParameter(0, while_shape, "init_value")); 143 builder.AddInstruction(HloInstruction::CreateWhile( 144 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 145 while_body, init_value)); 146 HloComputation* entry_computation = 147 module().AddEntryComputation(builder.Build()); 148 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 149 WhileLoopInvariantCodeMotion{}.Run(&module())); 150 EXPECT_TRUE(simplified_loop); 151 152 HloInstruction* transformed_while; 153 FindOnlyWhileInstruction(entry_computation, &transformed_while); 154 155 EXPECT_THAT(entry_computation->instructions(), 156 AllOf(Contains(op::Add()), Contains(op::Multiply()), 157 Contains(op::Negate()), Contains(op::Subtract()), 158 Contains(op::Constant()), 159 160 // The division had a loop varying operand so that better 161 // not be hoisted. 162 Not(Contains(op::Divide())))); 163 164 EXPECT_THAT(transformed_while->while_body()->instructions(), 165 Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(), 166 op::Subtract(), op::Constant())))); 167 168 EXPECT_THAT(transformed_while->while_body()->instructions(), 169 Contains(op::Divide())); 170 } 171 172 TEST_F(WhileLoopInvariantCodeMotionTest, 173 DontHoistTriviallyLoopVaryingComputation) { 174 // Basic negative test: the add expression is not loop invariant. 175 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 176 Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); 177 178 HloComputation* while_body = [&]() { 179 HloComputation::Builder builder(TestName() + ".while_body"); 180 HloInstruction* param = builder.AddInstruction( 181 HloInstruction::CreateParameter(0, while_shape, "param")); 182 HloInstruction* gte_0 = builder.AddInstruction( 183 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 184 HloInstruction* gte_1 = builder.AddInstruction( 185 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 186 HloInstruction* add_result = 187 builder.AddInstruction(HloInstruction::CreateBinary( 188 scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); 189 builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); 190 191 return module().AddEmbeddedComputation(builder.Build()); 192 }(); 193 194 HloComputation::Builder builder(TestName()); 195 auto* init_value = builder.AddInstruction( 196 HloInstruction::CreateParameter(0, while_shape, "init_value")); 197 auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( 198 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 199 while_body, init_value)); 200 201 module().AddEntryComputation(builder.Build()); 202 203 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 204 WhileLoopInvariantCodeMotion{}.Run(&module())); 205 EXPECT_FALSE(simplified_loop); 206 207 EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); 208 } 209 210 TEST_F(WhileLoopInvariantCodeMotionTest, 211 DontHoistLoopVaryingComputationWithAlternatingTuples) { 212 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 213 Shape while_shape = 214 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); 215 216 HloComputation* while_body = [&]() { 217 HloComputation::Builder builder(TestName() + ".while_body"); 218 HloInstruction* param = builder.AddInstruction( 219 HloInstruction::CreateParameter(0, while_shape, "param")); 220 HloInstruction* gte_0 = builder.AddInstruction( 221 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 222 HloInstruction* gte_1 = builder.AddInstruction( 223 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 224 HloInstruction* add_result = 225 builder.AddInstruction(HloInstruction::CreateBinary( 226 scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); 227 builder.AddInstruction( 228 HloInstruction::CreateTuple({gte_1, gte_0, add_result})); 229 230 return module().AddEmbeddedComputation(builder.Build()); 231 }(); 232 233 HloComputation::Builder builder(TestName()); 234 auto* init_value = builder.AddInstruction( 235 HloInstruction::CreateParameter(0, while_shape, "init_value")); 236 auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( 237 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 238 while_body, init_value)); 239 240 module().AddEntryComputation(builder.Build()); 241 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 242 WhileLoopInvariantCodeMotion{}.Run(&module())); 243 EXPECT_FALSE(simplified_loop); 244 245 EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); 246 } 247 248 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { 249 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 250 Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); 251 252 HloComputation* while_body = [&]() { 253 HloComputation::Builder builder(TestName() + ".while_body"); 254 HloInstruction* param = builder.AddInstruction( 255 HloInstruction::CreateParameter(0, while_shape, "param")); 256 HloInstruction* gte_0 = builder.AddInstruction( 257 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 258 HloInstruction* gte_1 = builder.AddInstruction( 259 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 260 builder.AddInstruction( 261 HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); 262 builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); 263 264 return module().AddEmbeddedComputation(builder.Build()); 265 }(); 266 267 HloComputation::Builder builder(TestName()); 268 auto* init_value = builder.AddInstruction( 269 HloInstruction::CreateParameter(0, while_shape, "init_value")); 270 auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( 271 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 272 while_body, init_value)); 273 274 module().AddEntryComputation(builder.Build()); 275 276 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 277 WhileLoopInvariantCodeMotion{}.Run(&module())); 278 EXPECT_FALSE(simplified_loop); 279 280 EXPECT_THAT(while_inst->while_body()->instructions(), 281 Contains(op::Outfeed())); 282 } 283 284 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { 285 // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the 286 // bitcast either. 287 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 288 auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); 289 Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); 290 291 HloComputation* while_body = [&]() { 292 HloComputation::Builder builder(TestName() + ".while_body"); 293 HloInstruction* param = builder.AddInstruction( 294 HloInstruction::CreateParameter(0, while_shape, "param")); 295 HloInstruction* gte_0 = builder.AddInstruction( 296 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 297 HloInstruction* gte_1 = builder.AddInstruction( 298 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 299 HloInstruction* bitcast_inst = builder.AddInstruction( 300 HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); 301 builder.AddInstruction( 302 HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); 303 builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); 304 305 return module().AddEmbeddedComputation(builder.Build()); 306 }(); 307 308 HloComputation::Builder builder(TestName()); 309 auto* init_value = builder.AddInstruction( 310 HloInstruction::CreateParameter(0, while_shape, "init_value")); 311 auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( 312 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 313 while_body, init_value)); 314 315 module().AddEntryComputation(builder.Build()); 316 317 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 318 WhileLoopInvariantCodeMotion{}.Run(&module())); 319 EXPECT_FALSE(simplified_loop); 320 321 EXPECT_THAT(while_inst->while_body()->instructions(), 322 Contains(op::Outfeed())); 323 EXPECT_THAT(while_inst->while_body()->instructions(), 324 Contains(op::Bitcast())); 325 } 326 327 TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { 328 // The bitcast's user can be hoisted, so hoist the bitcast too. 329 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 330 auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); 331 Shape while_shape = 332 ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); 333 334 HloComputation* while_body = [&]() { 335 HloComputation::Builder builder(TestName() + ".while_body"); 336 HloInstruction* param = builder.AddInstruction( 337 HloInstruction::CreateParameter(0, while_shape, "param")); 338 HloInstruction* gte_0 = builder.AddInstruction( 339 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 340 HloInstruction* gte_1 = builder.AddInstruction( 341 HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); 342 HloInstruction* bitcast_inst = builder.AddInstruction( 343 HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); 344 HloInstruction* add_inst = 345 builder.AddInstruction(HloInstruction::CreateBinary( 346 scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); 347 builder.AddInstruction( 348 HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); 349 350 return module().AddEmbeddedComputation(builder.Build()); 351 }(); 352 353 HloComputation::Builder builder(TestName()); 354 auto* init_value = builder.AddInstruction( 355 HloInstruction::CreateParameter(0, while_shape, "init_value")); 356 builder.AddInstruction(HloInstruction::CreateWhile( 357 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 358 while_body, init_value)); 359 360 HloComputation* entry_computation = 361 module().AddEntryComputation(builder.Build()); 362 363 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 364 WhileLoopInvariantCodeMotion{}.Run(&module())); 365 EXPECT_TRUE(simplified_loop); 366 367 HloInstruction* transformed_while; 368 FindOnlyWhileInstruction(entry_computation, &transformed_while); 369 370 EXPECT_THAT(transformed_while->while_body()->instructions(), 371 Each(Not(op::Add()))); 372 EXPECT_THAT(transformed_while->while_body()->instructions(), 373 Each(Not(op::Bitcast()))); 374 EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); 375 EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast())); 376 } 377 378 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { 379 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 380 Shape while_shape = 381 ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); 382 383 HloComputation* while_body; 384 { 385 HloComputation::Builder builder(TestName() + ".while_body"); 386 HloInstruction* param = builder.AddInstruction( 387 HloInstruction::CreateParameter(0, while_shape, "param")); 388 HloInstruction* gte_0 = builder.AddInstruction( 389 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); 390 HloInstruction* gte_1 = builder.AddInstruction( 391 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 392 HloInstruction* add_result = 393 builder.AddInstruction(HloInstruction::CreateBinary( 394 scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); 395 TF_ASSERT_OK(param->AddControlDependencyTo(add_result)); 396 builder.AddInstruction( 397 HloInstruction::CreateTuple({gte_0, gte_1, add_result})); 398 399 while_body = module().AddEmbeddedComputation(builder.Build()); 400 } 401 402 HloComputation::Builder builder(TestName()); 403 auto* init_value = builder.AddInstruction( 404 HloInstruction::CreateParameter(0, while_shape, "init_value")); 405 builder.AddInstruction(HloInstruction::CreateWhile( 406 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 407 while_body, init_value)); 408 module().AddEntryComputation(builder.Build()); 409 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 410 WhileLoopInvariantCodeMotion{}.Run(&module())); 411 EXPECT_FALSE(simplified_loop); 412 } 413 414 TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { 415 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 416 Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); 417 418 HloComputation* while_body = [&]() { 419 HloComputation::Builder builder(TestName() + ".passthrough"); 420 HloInstruction* param = builder.AddInstruction( 421 HloInstruction::CreateParameter(0, while_shape, "param")); 422 HloComputation* result = module().AddEmbeddedComputation(builder.Build()); 423 424 result->AddInstruction( 425 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); 426 return result; 427 }(); 428 429 HloComputation::Builder builder(TestName()); 430 auto* init_value = builder.AddInstruction( 431 HloInstruction::CreateParameter(0, while_shape, "init_value")); 432 builder.AddInstruction(HloInstruction::CreateWhile( 433 while_shape, MakeAlwaysTrueComputation(while_shape, &module()), 434 while_body, init_value)); 435 module().AddEntryComputation(builder.Build()); 436 TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, 437 WhileLoopInvariantCodeMotion{}.Run(&module())); 438 EXPECT_FALSE(simplified_loop); 439 } 440 441 } // namespace 442 } // namespace xla 443