Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/xla_data.pb.h"
     30 
     31 namespace xla {
     32 namespace {
     33 
     34 namespace op = xla::testing::opcode_matchers;
     35 
     36 using ::testing::_;
     37 
     38 class HloRematerializationTest : public HloTestBase {
     39  protected:
     40   // Creates and returns a computation which can benefit from
     41   // rematerialization. The computation looks like:
     42   //
     43   //   F32[] %param = {...}
     44   //   F32[1024] %bcast = broadcast(%param)
     45   //   F32[1024] %negate = negate(%bcast)
     46   //   F32[2048] %concat_1 = concat({%negate, %negate})
     47   //   F32[1] %slice_1 = slice(%concat_1, {0:1})
     48   //   F32[1025] %concat_2 = concat({%bcast, %slice_1})
     49   //   F32[1] %slice_2 = slice(%concat_2, {0:1});
     50   //
     51   // The instruction %bcast can be rematerialized before its use at %concat_2
     52   // to reduce peak memory usage. This avoids %bcast and %concat_1 being
     53   // simultaneously live. Peak memory use is about 16KB before rematerialization
     54   // (during execution of %concat_1) and about 12KB after rematerializing %bcast
     55   // for its use in %concat_2.
     56   std::unique_ptr<HloComputation> MakeRematerializableComputation(
     57       const string& suffix = "") {
     58     auto builder = HloComputation::Builder(TestName() + suffix);
     59     auto param = builder.AddInstruction(
     60         HloInstruction::CreateParameter(0, scalar_shape_, "param"));
     61     auto bcast = builder.AddInstruction(
     62         HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
     63     auto negate = builder.AddInstruction(
     64         HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
     65     auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
     66         ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate},
     67         /*dimension=*/0));
     68     auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
     69         vec1_shape_, concat_1, /*start_indices=*/{0},
     70         /*limit_indices=*/{1},
     71         /*strides=*/{1}));
     72     auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
     73         ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
     74         /*dimension=*/0));
     75     // Add a final slice to make the parameter shape match the output shape
     76     // which is necessary to use this computation in a while.
     77     builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
     78                                                        /*start_indices=*/{0},
     79                                                        /*limit_indices=*/{1},
     80                                                        /*strides=*/{1}));
     81     return builder.Build();
     82   }
     83 
     84   // Creates and returns a computation which includes a while and can benefit
     85   // from rematerialization. The computation looks like:
     86   //
     87   //   F32[] %param = {...}
     88   //   F32[1024] %bcast = broadcast(%param)
     89   //   F32[1] %slice_1 = slice(%bcast, {0:1})
     90   //   F32[1] %while = while(%slice_1, while_body, while_cond)
     91   //   F32[1025] %concat = concat({%bcast, %while})
     92   //   F32[1] %slice_2 = slice(%concat, {0:1});
     93   //
     94   // The instruction %bcast can be rematerialized before its use at %concat to
     95   // reduce peak memory usage. This avoids %bcast being live during execution of
     96   // the while. Peak memory use is maximum of 8K and 4K plus the memory use of
     97   // the while subcomputations.
     98   std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
     99       HloComputation* while_cond, HloComputation* while_body,
    100       const string& suffix = "") {
    101     auto builder = HloComputation::Builder(TestName() + suffix);
    102     auto param = builder.AddInstruction(
    103         HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    104     auto bcast = builder.AddInstruction(
    105         HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
    106     auto slice_1 = builder.AddInstruction(
    107         HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
    108                                     /*limit_indices=*/{1},
    109                                     /*strides=*/{1}));
    110     auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
    111         vec1_shape_, while_cond, while_body, slice_1));
    112     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
    113         ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
    114         /*dimension=*/0));
    115     builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
    116                                                        /*start_indices=*/{0},
    117                                                        /*limit_indices=*/{1},
    118                                                        /*strides=*/{1}));
    119     return builder.Build();
    120   }
    121 
    122   // Create and return a trivial computation appropriate for use as a while
    123   // condition.
    124   std::unique_ptr<HloComputation> MakeConditionComputation() {
    125     auto builder = HloComputation::Builder(TestName() + ".cond");
    126     builder.AddInstruction(
    127         HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    128     builder.AddInstruction(
    129         HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    130     return builder.Build();
    131   }
    132 
    133   // Return the byte size of the top-level buffer of the given shape.
    134   static int64 ByteSizeOf(const Shape& shape) {
    135     return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
    136   }
    137 
    138   // Various shapes used in the canned computations.
    139   const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
    140   const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
    141   const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024});
    142 };
    143 
    144 // Test rematerialization of a single computation produced by
    145 // MakeRematerializableComputation.
    146 TEST_F(HloRematerializationTest, SingleComputation) {
    147   auto module = CreateNewModule();
    148   HloComputation* computation =
    149       module->AddEntryComputation(MakeRematerializableComputation());
    150 
    151   // Find and save the original broadcast instruction which should be
    152   // rematerialized.
    153   const HloInstruction* slice = computation->root_instruction();
    154   ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
    155   const HloInstruction* concat = slice->operand(0);
    156   const HloInstruction* bcast = concat->operand(0);
    157 
    158   SequentialHloOrdering::HloModuleSequence sequence;
    159   // Computation requires 16KB without rematerialization, but uses only 12KB
    160   // with rematerialization so pick a memory limit between these values (14KB).
    161   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    162                           HloRematerialization::RematerializeAndSchedule(
    163                               ByteSizeOf,
    164                               /*memory_limit_bytes=*/14 * 1024, module.get(),
    165                               SchedulerAlgorithm::kAuto, &sequence));
    166   EXPECT_TRUE(changed);
    167 
    168   // Root should not have changed.
    169   EXPECT_EQ(computation->root_instruction(), slice);
    170 
    171   // The broadcast should have been rematerialized.
    172   const HloInstruction* remat_bcast = concat->operand(0);
    173   EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
    174 
    175   // The rematerialized broadcast should be immediate before the concat in the
    176   // sequence.
    177   EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
    178             concat);
    179   EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
    180             remat_bcast);
    181 }
    182 
    183 // Test rematerialization of a single computation produced by
    184 // MakeRematerializableComputation but with a sufficiently high memory limit
    185 // such that no instructions are rematerialized.
    186 TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
    187   auto module = CreateNewModule();
    188   HloComputation* computation =
    189       module->AddEntryComputation(MakeRematerializableComputation());
    190 
    191   EXPECT_EQ(computation->instruction_count(), 7);
    192 
    193   SequentialHloOrdering::HloModuleSequence sequence;
    194   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    195                           HloRematerialization::RematerializeAndSchedule(
    196                               ByteSizeOf,
    197                               /*memory_limit_bytes=*/20 * 1024, module.get(),
    198                               SchedulerAlgorithm::kAuto, &sequence));
    199 
    200   // No instructions should have been materialized.
    201   EXPECT_FALSE(changed);
    202   EXPECT_EQ(computation->instruction_count(), 7);
    203 }
    204 
    205 // Test rematerialization of a computation which calls another computation via a
    206 // while. Both the entry computation and while body computation can have memory
    207 // usage reduced via rematerialization however the memory limit is set such that
    208 // only one computation needs to have an instruction rematerialized. The entry
    209 // computation should be the one chosen because rematerialization in the while
    210 // will presumably be more expensive.
    211 TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
    212   auto module = CreateNewModule();
    213 
    214   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    215   cond_builder.AddInstruction(
    216       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    217   cond_builder.AddInstruction(
    218       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    219   HloComputation* while_cond =
    220       module->AddEmbeddedComputation(cond_builder.Build());
    221 
    222   HloComputation* body_computation = module->AddEmbeddedComputation(
    223       MakeRematerializableComputation(/*suffix=*/".body"));
    224   HloComputation* entry_computation =
    225       module->AddEntryComputation(MakeRematerializableWhileComputation(
    226           while_cond, /*while_body=*/body_computation));
    227 
    228   EXPECT_EQ(entry_computation->instruction_count(), 6);
    229   EXPECT_EQ(body_computation->instruction_count(), 7);
    230 
    231   // The body computation uses 16KB and the entry computation uses 2KB at the
    232   // while so the peak memory use of the module is 18KB. Set the memory limit a
    233   // bit lower (17KB) to force rematerialization of the entry computation.
    234   SequentialHloOrdering::HloModuleSequence sequence;
    235   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    236                           HloRematerialization::RematerializeAndSchedule(
    237                               ByteSizeOf,
    238                               /*memory_limit_bytes=*/17 * 1024, module.get(),
    239                               SchedulerAlgorithm::kAuto, &sequence));
    240   EXPECT_TRUE(changed);
    241 
    242   // Only the entry computation should have a rematerialized instruction added.
    243   EXPECT_EQ(entry_computation->instruction_count(), 7);
    244   EXPECT_EQ(body_computation->instruction_count(), 7);
    245 }
    246 
    247 // Test rematerialization of a computation which calls another computation via a
    248 // while. Both the entry computation and while body computation should have
    249 // computations rematerialized.
    250 TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
    251   auto module = CreateNewModule();
    252 
    253   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    254   cond_builder.AddInstruction(
    255       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    256   cond_builder.AddInstruction(
    257       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    258   HloComputation* while_cond =
    259       module->AddEmbeddedComputation(cond_builder.Build());
    260 
    261   HloComputation* body_computation = module->AddEmbeddedComputation(
    262       MakeRematerializableComputation(/*suffix=*/".body"));
    263   HloComputation* entry_computation =
    264       module->AddEntryComputation(MakeRematerializableWhileComputation(
    265           while_cond, /*while_body=*/body_computation));
    266 
    267   EXPECT_EQ(entry_computation->instruction_count(), 6);
    268   EXPECT_EQ(body_computation->instruction_count(), 7);
    269 
    270   SequentialHloOrdering::HloModuleSequence sequence;
    271   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    272                           HloRematerialization::RematerializeAndSchedule(
    273                               ByteSizeOf,
    274                               /*memory_limit_bytes=*/15 * 1024, module.get(),
    275                               SchedulerAlgorithm::kAuto, &sequence));
    276   EXPECT_TRUE(changed);
    277 
    278   // Both computations should have a rematerialized instruction added.
    279   EXPECT_EQ(entry_computation->instruction_count(), 7);
    280   EXPECT_EQ(body_computation->instruction_count(), 8);
    281 }
    282 
    283 // Test rematerialization of a doubly nested computation. All computations
    284 // should have an instruction rematerialized.
    285 TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
    286   auto module = CreateNewModule();
    287 
    288   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    289   cond_builder.AddInstruction(
    290       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    291   cond_builder.AddInstruction(
    292       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    293   HloComputation* while_cond =
    294       module->AddEmbeddedComputation(cond_builder.Build());
    295 
    296   HloComputation* inner_computation = module->AddEmbeddedComputation(
    297       MakeRematerializableComputation(/*suffix=*/".inner"));
    298   HloComputation* middle_computation =
    299       module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
    300           while_cond, /*while_body=*/inner_computation,
    301           /*suffix=*/".middle"));
    302   HloComputation* entry_computation =
    303       module->AddEntryComputation(MakeRematerializableWhileComputation(
    304           while_cond, /*while_body=*/middle_computation));
    305 
    306   EXPECT_EQ(entry_computation->instruction_count(), 6);
    307   EXPECT_EQ(middle_computation->instruction_count(), 6);
    308   EXPECT_EQ(inner_computation->instruction_count(), 7);
    309 
    310   // If all computations are maximally rematerialized then peak memory usage is
    311   // ~12K so pick something slightly larger.
    312   SequentialHloOrdering::HloModuleSequence sequence;
    313   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    314                           HloRematerialization::RematerializeAndSchedule(
    315                               ByteSizeOf,
    316                               /*memory_limit_bytes=*/13 * 1024, module.get(),
    317                               SchedulerAlgorithm::kAuto, &sequence));
    318   EXPECT_TRUE(changed);
    319 
    320   // All computations should have a rematerialized instruction added.
    321   EXPECT_EQ(entry_computation->instruction_count(), 7);
    322   EXPECT_EQ(middle_computation->instruction_count(), 7);
    323   EXPECT_EQ(inner_computation->instruction_count(), 8);
    324 }
    325 
    326 TEST_F(HloRematerializationTest, RngNotRematerialized) {
    327   // Test that a single rng is not rematerialized:
    328   //
    329   // Entry computation:
    330   //   F32[] %param = {...}
    331   //   F32[1024] rng = rng(param)
    332   //   F32[1024] tanh = tanh(rng)
    333   //   F32[1024] exp = exp(rng)
    334   //   F32[1024] add_0 = add(rng, tanh)              // LIVE: add_0 + rng +
    335   //                                                 //       tanh + exp
    336   //
    337   //   F32[1024] add_1 = add(rng, add(exp, add_0))   // LIVE: add_1 + add_0 +
    338   //                                                 //       rng + tanh + exp
    339   //
    340   //   F32[1024] add_2 = add(rng, add(tanh, add_1))  // LIVE: add_2 + add_1 +
    341   //                                                 //       rng + tanh + exp
    342   auto module = CreateNewModule();
    343 
    344   auto builder = HloComputation::Builder(TestName());
    345   auto param = builder.AddInstruction(
    346       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    347   auto rng = builder.AddInstruction(HloInstruction::CreateRng(
    348       vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param}));
    349   auto tanh = builder.AddInstruction(
    350       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng));
    351   auto exp = builder.AddInstruction(
    352       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng));
    353   auto add_0 = builder.AddInstruction(
    354       HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh));
    355   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
    356       vec1024_shape_, HloOpcode::kAdd, rng,
    357       builder.AddInstruction(HloInstruction::CreateBinary(
    358           vec1024_shape_, HloOpcode::kAdd, exp, add_0))));
    359   builder.AddInstruction(HloInstruction::CreateBinary(
    360       vec1024_shape_, HloOpcode::kAdd, rng,
    361       builder.AddInstruction(HloInstruction::CreateBinary(
    362           vec1024_shape_, HloOpcode::kAdd, tanh, add_1))));
    363   HloComputation* entry_computation =
    364       module->AddEntryComputation(builder.Build());
    365 
    366   auto count_rngs = [](const HloComputation* computation) {
    367     int64 rng_count = 0;
    368     for (auto* instruction : computation->instructions()) {
    369       if (instruction->opcode() == HloOpcode::kRng) {
    370         ++rng_count;
    371       }
    372     }
    373     return rng_count;
    374   };
    375   // Before rematerialization there should be a single broadcast rng in
    376   // the graph.
    377   ASSERT_EQ(count_rngs(entry_computation), 1);
    378   const int64 original_instruction_count =
    379       entry_computation->instruction_count();
    380   SequentialHloOrdering::HloModuleSequence sequence;
    381   // Pick a memory limit some where between 24KB (initial peak memory including
    382   // parameter and output) and 20KB (peak memory possible with
    383   // rematerialization).
    384   TF_ASSERT_OK_AND_ASSIGN(
    385       bool changed, HloRematerialization::RematerializeAndSchedule(
    386                         ByteSizeOf,
    387                         /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
    388                         module.get(), SchedulerAlgorithm::kAuto, &sequence));
    389   EXPECT_TRUE(changed);
    390   // The rng should not have been rematerialized.
    391   EXPECT_EQ(count_rngs(entry_computation), 1);
    392   // There should have been rematerialization.
    393   EXPECT_GT(entry_computation->instruction_count(), original_instruction_count);
    394 }
    395 
    396 TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
    397   // Test that a single instruction is rematerialized several times. Module:
    398   //
    399   // Entry computation:
    400   //   F32[] %param = {...}
    401   //   F32[1024] %bcast = broadcast(%param)
    402   //   F32[1024] %add_1 = add(%bcast, bcast)
    403   //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
    404   //   F32[1024] %add_2 = add(%bcast, call_1)
    405   //   F32[1024] %call_2 = call(SubComputation, {%add_2})
    406   //   F32[1024] %add_3 = add(%bcast, call_2)
    407   //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
    408   //   F32[1024] %add_4 = add(%bcast, call_3)
    409   //
    410   // Subcomputation:
    411   //   F32[1024] %param = {...}
    412   //   F32[2048] %concat = concat({%param, %param})
    413   //   F32[1024] %slice = slice(%concat)
    414   //
    415   // The value %bcast is live across each call of Subcomputation (which requires
    416   // 8KB) though the value is not used in the calls. Rematerializing %bcast
    417   // across these calls reduces peak memory use from ~20KB down to ~16KB.
    418   auto module = CreateNewModule();
    419 
    420   HloComputation* subcomputation = nullptr;
    421   {
    422     auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    423     auto param = builder.AddInstruction(
    424         HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    425     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
    426         ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
    427         /*dimension=*/0));
    428     builder.AddInstruction(HloInstruction::CreateSlice(
    429         vec1024_shape_, concat, /*start_indices=*/{0},
    430         /*limit_indices=*/{1024}, /*strides=*/{1}));
    431     subcomputation = module->AddEmbeddedComputation(builder.Build());
    432   }
    433 
    434   auto builder = HloComputation::Builder(TestName());
    435   auto param = builder.AddInstruction(
    436       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    437   auto bcast = builder.AddInstruction(
    438       HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
    439   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
    440       vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
    441   auto call_1 = builder.AddInstruction(
    442       HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
    443   auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
    444       vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
    445   auto call_2 = builder.AddInstruction(
    446       HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
    447   auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
    448       vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
    449   auto call_3 = builder.AddInstruction(
    450       HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
    451   auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
    452       vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
    453   HloComputation* entry_computation =
    454       module->AddEntryComputation(builder.Build());
    455 
    456   auto count_broadcasts = [](const HloComputation* computation) {
    457     int64 bcast_count = 0;
    458     for (auto* instruction : computation->instructions()) {
    459       if (instruction->opcode() == HloOpcode::kBroadcast) {
    460         bcast_count++;
    461       }
    462     }
    463     return bcast_count;
    464   };
    465 
    466   // Before rematerialization there should be a single broadcast instruction in
    467   // the graph.
    468   EXPECT_EQ(count_broadcasts(entry_computation), 1);
    469   EXPECT_EQ(entry_computation->instruction_count(), 9);
    470 
    471   EXPECT_EQ(add_2->operand(0), bcast);
    472   EXPECT_EQ(add_3->operand(0), bcast);
    473   EXPECT_EQ(add_4->operand(0), bcast);
    474 
    475   SequentialHloOrdering::HloModuleSequence sequence;
    476   // Pick a memory limit some where between 24KB (initial peak memory including
    477   // parameter and output) and 20KB (peak memory possible with
    478   // rematerialization).
    479   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    480                           HloRematerialization::RematerializeAndSchedule(
    481                               ByteSizeOf,
    482                               /*memory_limit_bytes=*/22 * 1024, module.get(),
    483                               SchedulerAlgorithm::kAuto, &sequence));
    484   EXPECT_TRUE(changed);
    485 
    486   // The broadcast should have been rematerialized 3 times.
    487   EXPECT_EQ(count_broadcasts(entry_computation), 4);
    488   EXPECT_EQ(entry_computation->instruction_count(), 12);
    489 
    490   // The operands of add_2, add_3, and add_4 should all be rematerialized
    491   // broadcasts.
    492   EXPECT_NE(add_2->operand(0), bcast);
    493   EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
    494   EXPECT_NE(add_3->operand(0), bcast);
    495   EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
    496   EXPECT_NE(add_4->operand(0), bcast);
    497   EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
    498 }
    499 
    500 class IndirectUseTest : public HloRematerializationTest,
    501                         public ::testing::WithParamInterface<bool> {};
    502 
    503 TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
    504   // Test that an rematerializable instruction is not rematerialized if it has
    505   // an indirect use. Test is parameterized on whether the value has an indirect
    506   // use, and the instruction should be rematerialized iff the value has no
    507   // indirect use. Module:
    508   //
    509   // Entry computation:
    510   //   F32[] %param = {...}
    511   //   F32[1024] %bcast = broadcast(%param)
    512   //   F32[1024] %add_1 = add(%bcast, bcast)
    513   //   F32[1024] %call = call(Subcomputation, {%add_1})
    514   //   F32[1024] %add_2 = add(%bcast, call)
    515   //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
    516   //   F32[1024] %gte = GetTupleElememt(%tuple, 0)
    517   //   F32[1024] %negate = negate(%gte)
    518   //
    519   // Subcomputation:
    520   //   F32[1024] %param = {...}
    521   //   F32[2048] %concat = concat({%param, %param})
    522   //   F32[1024] %slice = slice(%concat)
    523   //
    524   // The value %bcast is live across the call and rematerialization of %bcast
    525   // across that point would reduce peak memory use by 4KB. However, %bcast is
    526   // used indirectly in the %negate so rematerialization should not happen.
    527   //
    528   // This test is parameterized on whether the broadcast has an indirect use or
    529   // not. The indirect use is controlled by the index of the GetTupleElement
    530   // instruction. If the element is 0, then the %negate operand aliases %bcast
    531   // (ie %bcast is used indirectly by %negate), otherwise the %negate operand
    532   // aliases %add_2.
    533   const bool indirectly_used = GetParam();
    534   auto module = CreateNewModule();
    535 
    536   HloComputation* subcomputation = nullptr;
    537   {
    538     auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    539     auto param = builder.AddInstruction(
    540         HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    541     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
    542         ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
    543         /*dimension=*/0));
    544     builder.AddInstruction(HloInstruction::CreateSlice(
    545         vec1024_shape_, concat, /*start_indices=*/{0},
    546         /*limit_indices=*/{1024}, /*strides=*/{1}));
    547     subcomputation = module->AddEmbeddedComputation(builder.Build());
    548   }
    549 
    550   auto builder = HloComputation::Builder(TestName());
    551   auto param = builder.AddInstruction(
    552       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    553   auto bcast = builder.AddInstruction(
    554       HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
    555   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
    556       vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
    557   auto call_1 = builder.AddInstruction(
    558       HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
    559   auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
    560       vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
    561   auto tuple =
    562       builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
    563   auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    564       vec1024_shape_, tuple, indirectly_used ? 0 : 1));
    565   builder.AddInstruction(
    566       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
    567   HloComputation* entry_computation =
    568       module->AddEntryComputation(builder.Build());
    569 
    570   EXPECT_EQ(entry_computation->instruction_count(), 8);
    571 
    572   SequentialHloOrdering::HloModuleSequence sequence;
    573   // Pick a memory limit some where between 24KB (initial peak memory including
    574   // parameter and output) and 20KB (peak memory possible with
    575   // rematerialization).
    576   TF_ASSERT_OK_AND_ASSIGN(bool changed,
    577                           HloRematerialization::RematerializeAndSchedule(
    578                               ByteSizeOf,
    579                               /*memory_limit_bytes=*/22 * 1024, module.get(),
    580                               SchedulerAlgorithm::kAuto, &sequence));
    581   // Rematerialization should only occur if the rematerializable instruction has
    582   // no indirect uses.
    583   if (indirectly_used) {
    584     EXPECT_FALSE(changed);
    585     EXPECT_EQ(entry_computation->instruction_count(), 8);
    586   } else {
    587     EXPECT_TRUE(changed);
    588     EXPECT_EQ(entry_computation->instruction_count(), 9);
    589   }
    590 }
    591 
    592 INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest,
    593                         ::testing::Values(true, false));
    594 
    595 }  // namespace
    596 
    597 }  // namespace xla
    598