1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h" 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 24 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 25 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 31 namespace xla { 32 namespace { 33 34 namespace op = xla::testing::opcode_matchers; 35 36 using ::testing::_; 37 38 class HloRematerializationTest : public HloTestBase { 39 protected: 40 // Creates and returns a computation which can benefit from 41 // rematerialization. The computation looks like: 42 // 43 // F32[] %param = {...} 44 // F32[1024] %bcast = broadcast(%param) 45 // F32[1024] %negate = negate(%bcast) 46 // F32[2048] %concat_1 = concat({%negate, %negate}) 47 // F32[1] %slice_1 = slice(%concat_1, {0:1}) 48 // F32[1025] %concat_2 = concat({%bcast, %slice_1}) 49 // F32[1] %slice_2 = slice(%concat_2, {0:1}); 50 // 51 // The instruction %bcast can be rematerialized before its use at %concat_2 52 // to reduce peak memory usage. This avoids %bcast and %concat_1 being 53 // simultaneously live. Peak memory use is about 16KB before rematerialization 54 // (during execution of %concat_1) and about 12KB after rematerializing %bcast 55 // for its use in %concat_2. 56 std::unique_ptr<HloComputation> MakeRematerializableComputation( 57 const string& suffix = "") { 58 auto builder = HloComputation::Builder(TestName() + suffix); 59 auto param = builder.AddInstruction( 60 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 61 auto bcast = builder.AddInstruction( 62 HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); 63 auto negate = builder.AddInstruction( 64 HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); 65 auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( 66 ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate}, 67 /*dimension=*/0)); 68 auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice( 69 vec1_shape_, concat_1, /*start_indices=*/{0}, 70 /*limit_indices=*/{1}, 71 /*strides=*/{1})); 72 auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate( 73 ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1}, 74 /*dimension=*/0)); 75 // Add a final slice to make the parameter shape match the output shape 76 // which is necessary to use this computation in a while. 77 builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2, 78 /*start_indices=*/{0}, 79 /*limit_indices=*/{1}, 80 /*strides=*/{1})); 81 return builder.Build(); 82 } 83 84 // Creates and returns a computation which includes a while and can benefit 85 // from rematerialization. The computation looks like: 86 // 87 // F32[] %param = {...} 88 // F32[1024] %bcast = broadcast(%param) 89 // F32[1] %slice_1 = slice(%bcast, {0:1}) 90 // F32[1] %while = while(%slice_1, while_body, while_cond) 91 // F32[1025] %concat = concat({%bcast, %while}) 92 // F32[1] %slice_2 = slice(%concat, {0:1}); 93 // 94 // The instruction %bcast can be rematerialized before its use at %concat to 95 // reduce peak memory usage. This avoids %bcast being live during execution of 96 // the while. Peak memory use is maximum of 8K and 4K plus the memory use of 97 // the while subcomputations. 98 std::unique_ptr<HloComputation> MakeRematerializableWhileComputation( 99 HloComputation* while_cond, HloComputation* while_body, 100 const string& suffix = "") { 101 auto builder = HloComputation::Builder(TestName() + suffix); 102 auto param = builder.AddInstruction( 103 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 104 auto bcast = builder.AddInstruction( 105 HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); 106 auto slice_1 = builder.AddInstruction( 107 HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, 108 /*limit_indices=*/{1}, 109 /*strides=*/{1})); 110 auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile( 111 vec1_shape_, while_cond, while_body, slice_1)); 112 auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( 113 ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst}, 114 /*dimension=*/0)); 115 builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat, 116 /*start_indices=*/{0}, 117 /*limit_indices=*/{1}, 118 /*strides=*/{1})); 119 return builder.Build(); 120 } 121 122 // Create and return a trivial computation appropriate for use as a while 123 // condition. 124 std::unique_ptr<HloComputation> MakeConditionComputation() { 125 auto builder = HloComputation::Builder(TestName() + ".cond"); 126 builder.AddInstruction( 127 HloInstruction::CreateParameter(0, vec1_shape_, "param")); 128 builder.AddInstruction( 129 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 130 return builder.Build(); 131 } 132 133 // Return the byte size of the top-level buffer of the given shape. 134 static int64 ByteSizeOf(const Shape& shape) { 135 return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); 136 } 137 138 // Various shapes used in the canned computations. 139 const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); 140 const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); 141 const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); 142 }; 143 144 // Test rematerialization of a single computation produced by 145 // MakeRematerializableComputation. 146 TEST_F(HloRematerializationTest, SingleComputation) { 147 auto module = CreateNewModule(); 148 HloComputation* computation = 149 module->AddEntryComputation(MakeRematerializableComputation()); 150 151 // Find and save the original broadcast instruction which should be 152 // rematerialized. 153 const HloInstruction* slice = computation->root_instruction(); 154 ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _))); 155 const HloInstruction* concat = slice->operand(0); 156 const HloInstruction* bcast = concat->operand(0); 157 158 SequentialHloOrdering::HloModuleSequence sequence; 159 // Computation requires 16KB without rematerialization, but uses only 12KB 160 // with rematerialization so pick a memory limit between these values (14KB). 161 TF_ASSERT_OK_AND_ASSIGN(bool changed, 162 HloRematerialization::RematerializeAndSchedule( 163 ByteSizeOf, 164 /*memory_limit_bytes=*/14 * 1024, module.get(), 165 SchedulerAlgorithm::kAuto, &sequence)); 166 EXPECT_TRUE(changed); 167 168 // Root should not have changed. 169 EXPECT_EQ(computation->root_instruction(), slice); 170 171 // The broadcast should have been rematerialized. 172 const HloInstruction* remat_bcast = concat->operand(0); 173 EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast))); 174 175 // The rematerialized broadcast should be immediate before the concat in the 176 // sequence. 177 EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], 178 concat); 179 EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], 180 remat_bcast); 181 } 182 183 // Test rematerialization of a single computation produced by 184 // MakeRematerializableComputation but with a sufficiently high memory limit 185 // such that no instructions are rematerialized. 186 TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { 187 auto module = CreateNewModule(); 188 HloComputation* computation = 189 module->AddEntryComputation(MakeRematerializableComputation()); 190 191 EXPECT_EQ(computation->instruction_count(), 7); 192 193 SequentialHloOrdering::HloModuleSequence sequence; 194 TF_ASSERT_OK_AND_ASSIGN(bool changed, 195 HloRematerialization::RematerializeAndSchedule( 196 ByteSizeOf, 197 /*memory_limit_bytes=*/20 * 1024, module.get(), 198 SchedulerAlgorithm::kAuto, &sequence)); 199 200 // No instructions should have been materialized. 201 EXPECT_FALSE(changed); 202 EXPECT_EQ(computation->instruction_count(), 7); 203 } 204 205 // Test rematerialization of a computation which calls another computation via a 206 // while. Both the entry computation and while body computation can have memory 207 // usage reduced via rematerialization however the memory limit is set such that 208 // only one computation needs to have an instruction rematerialized. The entry 209 // computation should be the one chosen because rematerialization in the while 210 // will presumably be more expensive. 211 TEST_F(HloRematerializationTest, RematerializeAroundWhile) { 212 auto module = CreateNewModule(); 213 214 auto cond_builder = HloComputation::Builder(TestName() + ".cond"); 215 cond_builder.AddInstruction( 216 HloInstruction::CreateParameter(0, vec1_shape_, "param")); 217 cond_builder.AddInstruction( 218 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 219 HloComputation* while_cond = 220 module->AddEmbeddedComputation(cond_builder.Build()); 221 222 HloComputation* body_computation = module->AddEmbeddedComputation( 223 MakeRematerializableComputation(/*suffix=*/".body")); 224 HloComputation* entry_computation = 225 module->AddEntryComputation(MakeRematerializableWhileComputation( 226 while_cond, /*while_body=*/body_computation)); 227 228 EXPECT_EQ(entry_computation->instruction_count(), 6); 229 EXPECT_EQ(body_computation->instruction_count(), 7); 230 231 // The body computation uses 16KB and the entry computation uses 2KB at the 232 // while so the peak memory use of the module is 18KB. Set the memory limit a 233 // bit lower (17KB) to force rematerialization of the entry computation. 234 SequentialHloOrdering::HloModuleSequence sequence; 235 TF_ASSERT_OK_AND_ASSIGN(bool changed, 236 HloRematerialization::RematerializeAndSchedule( 237 ByteSizeOf, 238 /*memory_limit_bytes=*/17 * 1024, module.get(), 239 SchedulerAlgorithm::kAuto, &sequence)); 240 EXPECT_TRUE(changed); 241 242 // Only the entry computation should have a rematerialized instruction added. 243 EXPECT_EQ(entry_computation->instruction_count(), 7); 244 EXPECT_EQ(body_computation->instruction_count(), 7); 245 } 246 247 // Test rematerialization of a computation which calls another computation via a 248 // while. Both the entry computation and while body computation should have 249 // computations rematerialized. 250 TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { 251 auto module = CreateNewModule(); 252 253 auto cond_builder = HloComputation::Builder(TestName() + ".cond"); 254 cond_builder.AddInstruction( 255 HloInstruction::CreateParameter(0, vec1_shape_, "param")); 256 cond_builder.AddInstruction( 257 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 258 HloComputation* while_cond = 259 module->AddEmbeddedComputation(cond_builder.Build()); 260 261 HloComputation* body_computation = module->AddEmbeddedComputation( 262 MakeRematerializableComputation(/*suffix=*/".body")); 263 HloComputation* entry_computation = 264 module->AddEntryComputation(MakeRematerializableWhileComputation( 265 while_cond, /*while_body=*/body_computation)); 266 267 EXPECT_EQ(entry_computation->instruction_count(), 6); 268 EXPECT_EQ(body_computation->instruction_count(), 7); 269 270 SequentialHloOrdering::HloModuleSequence sequence; 271 TF_ASSERT_OK_AND_ASSIGN(bool changed, 272 HloRematerialization::RematerializeAndSchedule( 273 ByteSizeOf, 274 /*memory_limit_bytes=*/15 * 1024, module.get(), 275 SchedulerAlgorithm::kAuto, &sequence)); 276 EXPECT_TRUE(changed); 277 278 // Both computations should have a rematerialized instruction added. 279 EXPECT_EQ(entry_computation->instruction_count(), 7); 280 EXPECT_EQ(body_computation->instruction_count(), 8); 281 } 282 283 // Test rematerialization of a doubly nested computation. All computations 284 // should have an instruction rematerialized. 285 TEST_F(HloRematerializationTest, RematerializeNestedComputations) { 286 auto module = CreateNewModule(); 287 288 auto cond_builder = HloComputation::Builder(TestName() + ".cond"); 289 cond_builder.AddInstruction( 290 HloInstruction::CreateParameter(0, vec1_shape_, "param")); 291 cond_builder.AddInstruction( 292 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 293 HloComputation* while_cond = 294 module->AddEmbeddedComputation(cond_builder.Build()); 295 296 HloComputation* inner_computation = module->AddEmbeddedComputation( 297 MakeRematerializableComputation(/*suffix=*/".inner")); 298 HloComputation* middle_computation = 299 module->AddEmbeddedComputation(MakeRematerializableWhileComputation( 300 while_cond, /*while_body=*/inner_computation, 301 /*suffix=*/".middle")); 302 HloComputation* entry_computation = 303 module->AddEntryComputation(MakeRematerializableWhileComputation( 304 while_cond, /*while_body=*/middle_computation)); 305 306 EXPECT_EQ(entry_computation->instruction_count(), 6); 307 EXPECT_EQ(middle_computation->instruction_count(), 6); 308 EXPECT_EQ(inner_computation->instruction_count(), 7); 309 310 // If all computations are maximally rematerialized then peak memory usage is 311 // ~12K so pick something slightly larger. 312 SequentialHloOrdering::HloModuleSequence sequence; 313 TF_ASSERT_OK_AND_ASSIGN(bool changed, 314 HloRematerialization::RematerializeAndSchedule( 315 ByteSizeOf, 316 /*memory_limit_bytes=*/13 * 1024, module.get(), 317 SchedulerAlgorithm::kAuto, &sequence)); 318 EXPECT_TRUE(changed); 319 320 // All computations should have a rematerialized instruction added. 321 EXPECT_EQ(entry_computation->instruction_count(), 7); 322 EXPECT_EQ(middle_computation->instruction_count(), 7); 323 EXPECT_EQ(inner_computation->instruction_count(), 8); 324 } 325 326 TEST_F(HloRematerializationTest, RngNotRematerialized) { 327 // Test that a single rng is not rematerialized: 328 // 329 // Entry computation: 330 // F32[] %param = {...} 331 // F32[1024] rng = rng(param) 332 // F32[1024] tanh = tanh(rng) 333 // F32[1024] exp = exp(rng) 334 // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + 335 // // tanh + exp 336 // 337 // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + 338 // // rng + tanh + exp 339 // 340 // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + 341 // // rng + tanh + exp 342 auto module = CreateNewModule(); 343 344 auto builder = HloComputation::Builder(TestName()); 345 auto param = builder.AddInstruction( 346 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 347 auto rng = builder.AddInstruction(HloInstruction::CreateRng( 348 vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param})); 349 auto tanh = builder.AddInstruction( 350 HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); 351 auto exp = builder.AddInstruction( 352 HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); 353 auto add_0 = builder.AddInstruction( 354 HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); 355 auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( 356 vec1024_shape_, HloOpcode::kAdd, rng, 357 builder.AddInstruction(HloInstruction::CreateBinary( 358 vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); 359 builder.AddInstruction(HloInstruction::CreateBinary( 360 vec1024_shape_, HloOpcode::kAdd, rng, 361 builder.AddInstruction(HloInstruction::CreateBinary( 362 vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); 363 HloComputation* entry_computation = 364 module->AddEntryComputation(builder.Build()); 365 366 auto count_rngs = [](const HloComputation* computation) { 367 int64 rng_count = 0; 368 for (auto* instruction : computation->instructions()) { 369 if (instruction->opcode() == HloOpcode::kRng) { 370 ++rng_count; 371 } 372 } 373 return rng_count; 374 }; 375 // Before rematerialization there should be a single broadcast rng in 376 // the graph. 377 ASSERT_EQ(count_rngs(entry_computation), 1); 378 const int64 original_instruction_count = 379 entry_computation->instruction_count(); 380 SequentialHloOrdering::HloModuleSequence sequence; 381 // Pick a memory limit some where between 24KB (initial peak memory including 382 // parameter and output) and 20KB (peak memory possible with 383 // rematerialization). 384 TF_ASSERT_OK_AND_ASSIGN( 385 bool changed, HloRematerialization::RematerializeAndSchedule( 386 ByteSizeOf, 387 /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), 388 module.get(), SchedulerAlgorithm::kAuto, &sequence)); 389 EXPECT_TRUE(changed); 390 // The rng should not have been rematerialized. 391 EXPECT_EQ(count_rngs(entry_computation), 1); 392 // There should have been rematerialization. 393 EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); 394 } 395 396 TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { 397 // Test that a single instruction is rematerialized several times. Module: 398 // 399 // Entry computation: 400 // F32[] %param = {...} 401 // F32[1024] %bcast = broadcast(%param) 402 // F32[1024] %add_1 = add(%bcast, bcast) 403 // F32[1024] %call_1 = call(Subcomputation, {%add_1}) 404 // F32[1024] %add_2 = add(%bcast, call_1) 405 // F32[1024] %call_2 = call(SubComputation, {%add_2}) 406 // F32[1024] %add_3 = add(%bcast, call_2) 407 // F32[1024] %call_3 = call(Subcomputation, {%add_3}) 408 // F32[1024] %add_4 = add(%bcast, call_3) 409 // 410 // Subcomputation: 411 // F32[1024] %param = {...} 412 // F32[2048] %concat = concat({%param, %param}) 413 // F32[1024] %slice = slice(%concat) 414 // 415 // The value %bcast is live across each call of Subcomputation (which requires 416 // 8KB) though the value is not used in the calls. Rematerializing %bcast 417 // across these calls reduces peak memory use from ~20KB down to ~16KB. 418 auto module = CreateNewModule(); 419 420 HloComputation* subcomputation = nullptr; 421 { 422 auto builder = HloComputation::Builder(TestName() + ".subcomputation"); 423 auto param = builder.AddInstruction( 424 HloInstruction::CreateParameter(0, vec1024_shape_, "param")); 425 auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( 426 ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, 427 /*dimension=*/0)); 428 builder.AddInstruction(HloInstruction::CreateSlice( 429 vec1024_shape_, concat, /*start_indices=*/{0}, 430 /*limit_indices=*/{1024}, /*strides=*/{1})); 431 subcomputation = module->AddEmbeddedComputation(builder.Build()); 432 } 433 434 auto builder = HloComputation::Builder(TestName()); 435 auto param = builder.AddInstruction( 436 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 437 auto bcast = builder.AddInstruction( 438 HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); 439 auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( 440 vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); 441 auto call_1 = builder.AddInstruction( 442 HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); 443 auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( 444 vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); 445 auto call_2 = builder.AddInstruction( 446 HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation)); 447 auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary( 448 vec1024_shape_, HloOpcode::kAdd, bcast, call_2)); 449 auto call_3 = builder.AddInstruction( 450 HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation)); 451 auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary( 452 vec1024_shape_, HloOpcode::kAdd, bcast, call_3)); 453 HloComputation* entry_computation = 454 module->AddEntryComputation(builder.Build()); 455 456 auto count_broadcasts = [](const HloComputation* computation) { 457 int64 bcast_count = 0; 458 for (auto* instruction : computation->instructions()) { 459 if (instruction->opcode() == HloOpcode::kBroadcast) { 460 bcast_count++; 461 } 462 } 463 return bcast_count; 464 }; 465 466 // Before rematerialization there should be a single broadcast instruction in 467 // the graph. 468 EXPECT_EQ(count_broadcasts(entry_computation), 1); 469 EXPECT_EQ(entry_computation->instruction_count(), 9); 470 471 EXPECT_EQ(add_2->operand(0), bcast); 472 EXPECT_EQ(add_3->operand(0), bcast); 473 EXPECT_EQ(add_4->operand(0), bcast); 474 475 SequentialHloOrdering::HloModuleSequence sequence; 476 // Pick a memory limit some where between 24KB (initial peak memory including 477 // parameter and output) and 20KB (peak memory possible with 478 // rematerialization). 479 TF_ASSERT_OK_AND_ASSIGN(bool changed, 480 HloRematerialization::RematerializeAndSchedule( 481 ByteSizeOf, 482 /*memory_limit_bytes=*/22 * 1024, module.get(), 483 SchedulerAlgorithm::kAuto, &sequence)); 484 EXPECT_TRUE(changed); 485 486 // The broadcast should have been rematerialized 3 times. 487 EXPECT_EQ(count_broadcasts(entry_computation), 4); 488 EXPECT_EQ(entry_computation->instruction_count(), 12); 489 490 // The operands of add_2, add_3, and add_4 should all be rematerialized 491 // broadcasts. 492 EXPECT_NE(add_2->operand(0), bcast); 493 EXPECT_THAT(add_2->operand(0), op::Broadcast(param)); 494 EXPECT_NE(add_3->operand(0), bcast); 495 EXPECT_THAT(add_3->operand(0), op::Broadcast(param)); 496 EXPECT_NE(add_4->operand(0), bcast); 497 EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); 498 } 499 500 class IndirectUseTest : public HloRematerializationTest, 501 public ::testing::WithParamInterface<bool> {}; 502 503 TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { 504 // Test that an rematerializable instruction is not rematerialized if it has 505 // an indirect use. Test is parameterized on whether the value has an indirect 506 // use, and the instruction should be rematerialized iff the value has no 507 // indirect use. Module: 508 // 509 // Entry computation: 510 // F32[] %param = {...} 511 // F32[1024] %bcast = broadcast(%param) 512 // F32[1024] %add_1 = add(%bcast, bcast) 513 // F32[1024] %call = call(Subcomputation, {%add_1}) 514 // F32[1024] %add_2 = add(%bcast, call) 515 // {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2) 516 // F32[1024] %gte = GetTupleElememt(%tuple, 0) 517 // F32[1024] %negate = negate(%gte) 518 // 519 // Subcomputation: 520 // F32[1024] %param = {...} 521 // F32[2048] %concat = concat({%param, %param}) 522 // F32[1024] %slice = slice(%concat) 523 // 524 // The value %bcast is live across the call and rematerialization of %bcast 525 // across that point would reduce peak memory use by 4KB. However, %bcast is 526 // used indirectly in the %negate so rematerialization should not happen. 527 // 528 // This test is parameterized on whether the broadcast has an indirect use or 529 // not. The indirect use is controlled by the index of the GetTupleElement 530 // instruction. If the element is 0, then the %negate operand aliases %bcast 531 // (ie %bcast is used indirectly by %negate), otherwise the %negate operand 532 // aliases %add_2. 533 const bool indirectly_used = GetParam(); 534 auto module = CreateNewModule(); 535 536 HloComputation* subcomputation = nullptr; 537 { 538 auto builder = HloComputation::Builder(TestName() + ".subcomputation"); 539 auto param = builder.AddInstruction( 540 HloInstruction::CreateParameter(0, vec1024_shape_, "param")); 541 auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( 542 ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, 543 /*dimension=*/0)); 544 builder.AddInstruction(HloInstruction::CreateSlice( 545 vec1024_shape_, concat, /*start_indices=*/{0}, 546 /*limit_indices=*/{1024}, /*strides=*/{1})); 547 subcomputation = module->AddEmbeddedComputation(builder.Build()); 548 } 549 550 auto builder = HloComputation::Builder(TestName()); 551 auto param = builder.AddInstruction( 552 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 553 auto bcast = builder.AddInstruction( 554 HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); 555 auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( 556 vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); 557 auto call_1 = builder.AddInstruction( 558 HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); 559 auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( 560 vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); 561 auto tuple = 562 builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2})); 563 auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 564 vec1024_shape_, tuple, indirectly_used ? 0 : 1)); 565 builder.AddInstruction( 566 HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte)); 567 HloComputation* entry_computation = 568 module->AddEntryComputation(builder.Build()); 569 570 EXPECT_EQ(entry_computation->instruction_count(), 8); 571 572 SequentialHloOrdering::HloModuleSequence sequence; 573 // Pick a memory limit some where between 24KB (initial peak memory including 574 // parameter and output) and 20KB (peak memory possible with 575 // rematerialization). 576 TF_ASSERT_OK_AND_ASSIGN(bool changed, 577 HloRematerialization::RematerializeAndSchedule( 578 ByteSizeOf, 579 /*memory_limit_bytes=*/22 * 1024, module.get(), 580 SchedulerAlgorithm::kAuto, &sequence)); 581 // Rematerialization should only occur if the rematerializable instruction has 582 // no indirect uses. 583 if (indirectly_used) { 584 EXPECT_FALSE(changed); 585 EXPECT_EQ(entry_computation->instruction_count(), 8); 586 } else { 587 EXPECT_TRUE(changed); 588 EXPECT_EQ(entry_computation->instruction_count(), 9); 589 } 590 } 591 592 INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, 593 ::testing::Values(true, false)); 594 595 } // namespace 596 597 } // namespace xla 598