Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     19 #include "tensorflow/compiler/xla/test.h"
     20 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
     21 #include "tensorflow/core/lib/core/status_test_util.h"
     22 
     23 namespace xla {
     24 namespace {
     25 
     26 namespace op = xla::testing::opcode_matchers;
     27 
     28 class WhileLoopSimplifierTest : public HloVerifiedTestBase {
     29  public:
     30   // Makes a computation that contains a loop that runs num_iters times.
     31   HloComputation* MakeSimpleLoop(int num_iters, HloModule* module);
     32 
     33   // Makes a computation which has one parameter, of the given shape, and always
     34   // returns PRED[]{true}.  This is useful as a dummy loop condition.
     35   HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
     36                                             HloModule* module);
     37 };
     38 
     39 HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters,
     40                                                         HloModule* module) {
     41   HloComputation::Builder builder(TestName());
     42 
     43   auto loop_iter_init = builder.AddInstruction(
     44       HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
     45   auto loop_data_init = builder.AddInstruction(
     46       HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1, 2})));
     47   auto loop_init = builder.AddInstruction(
     48       HloInstruction::CreateTuple({loop_iter_init, loop_data_init}));
     49 
     50   HloComputation* condition;
     51   {
     52     HloComputation::Builder cond_builder(TestName() + ".condition");
     53     auto loop_var = cond_builder.AddInstruction(
     54         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
     55     auto loop_induction_var =
     56         cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
     57             ShapeUtil::MakeShape(S32, {}), loop_var, 0));
     58     auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant(
     59         Literal::CreateR0<int32>(42 + num_iters)));
     60     cond_builder.AddInstruction(HloInstruction::CreateBinary(
     61         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var,
     62         limit));
     63     condition = module->AddEmbeddedComputation(cond_builder.Build());
     64   }
     65 
     66   HloComputation* body;
     67   {
     68     HloComputation::Builder body_builder(TestName() + ".body");
     69     auto loop_var = body_builder.AddInstruction(
     70         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
     71     auto loop_induction_var =
     72         body_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
     73             ShapeUtil::MakeShape(S32, {}), loop_var, 0));
     74     auto new_loop_induction_var =
     75         body_builder.AddInstruction(HloInstruction::CreateBinary(
     76             loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var,
     77             body_builder.AddInstruction(
     78                 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)))));
     79     auto loop_data =
     80         body_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
     81             loop_data_init->shape(), loop_var, 1));
     82     auto new_loop_data =
     83         body_builder.AddInstruction(HloInstruction::CreateBinary(
     84             loop_data_init->shape(), HloOpcode::kMultiply, loop_data,
     85             loop_data));
     86     body_builder.AddInstruction(
     87         HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data}));
     88     body = module->AddEmbeddedComputation(body_builder.Build());
     89   }
     90 
     91   builder.AddInstruction(HloInstruction::CreateWhile(
     92       loop_init->shape(), condition, body, loop_init));
     93 
     94   return module->AddEntryComputation(builder.Build());
     95 }
     96 
     97 HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation(
     98     const Shape& param_shape, HloModule* module) {
     99   HloComputation::Builder builder(TestName() + ".always_true");
    100   builder.AddInstruction(
    101       HloInstruction::CreateParameter(0, param_shape, "param"));
    102   builder.AddInstruction(
    103       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    104   return module->AddEmbeddedComputation(builder.Build());
    105 }
    106 
    107 TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) {
    108   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module());
    109   ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    110   EXPECT_THAT(computation->root_instruction(),
    111               op::Tuple(op::Constant(), op::Constant()));
    112 }
    113 
    114 TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) {
    115   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
    116   ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    117   EXPECT_THAT(computation->root_instruction(),
    118               op::Tuple(op::Add(), op::Multiply()));
    119 }
    120 
    121 TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) {
    122   MakeSimpleLoop(/*num_iters=*/2, &module());
    123   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    124 }
    125 
    126 TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) {
    127   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
    128   auto* while_op = computation->root_instruction();
    129   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
    130   auto* true_op = while_op->while_body()->AddInstruction(
    131       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
    132   TF_ASSERT_OK(true_op->AddControlDependencyTo(
    133       while_op->while_body()->root_instruction()));
    134   ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    135   EXPECT_THAT(computation->root_instruction()->control_predecessors(),
    136               ElementsAre(op::Constant()))
    137       << computation->ToString();
    138 }
    139 
    140 // Loops that contain send/recv nodes can't be simplified; the loop structure
    141 // around send/recv nodes must be preserved.
    142 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
    143   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
    144   auto* while_op = computation->root_instruction();
    145   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
    146   auto* while_body = while_op->while_body();
    147   auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
    148       while_body->AddInstruction(
    149           HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
    150       /*channel_id=*/0));
    151   while_body->AddInstruction(HloInstruction::CreateSendDone(send));
    152   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    153 }
    154 
    155 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
    156   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
    157   auto* while_op = computation->root_instruction();
    158   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
    159   auto* while_body = while_op->while_body();
    160   auto* recv = while_body->AddInstruction(
    161       HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
    162                                  /*channel_id=*/0));
    163   while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
    164   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    165 }
    166 
    167 // The limitation on not being able to simplify loops that contain infeeds (and
    168 // other non-removable instructions) isn't fundamental -- it just stems from the
    169 // fact that our infrastructure sees simplifying such a loop as tantamount to
    170 // removing the non-removable instruction.
    171 TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
    172   HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
    173   auto* while_op = computation->root_instruction();
    174   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
    175   auto* while_body = while_op->while_body();
    176   while_body->AddInstruction(
    177       HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
    178   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    179 }
    180 
    181 // Check that we don't crash when given a loop whose shape is not a tuple.
    182 TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) {
    183   HloComputation::Builder builder(TestName());
    184   auto loop_init = builder.AddInstruction(
    185       HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
    186 
    187   HloComputation* condition;
    188   {
    189     HloComputation::Builder cond_builder(TestName() + ".condition");
    190     auto param = cond_builder.AddInstruction(
    191         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
    192     cond_builder.AddInstruction(HloInstruction::CreateBinary(
    193         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param,
    194         cond_builder.AddInstruction(
    195             HloInstruction::CreateConstant(Literal::CreateR0<int32>(100)))));
    196     condition = module().AddEmbeddedComputation(cond_builder.Build());
    197   }
    198 
    199   HloComputation* body;
    200   {
    201     HloComputation::Builder body_builder(TestName() + ".body");
    202     auto param = body_builder.AddInstruction(
    203         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
    204     body_builder.AddInstruction(HloInstruction::CreateBinary(
    205         ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param,
    206         body_builder.AddInstruction(
    207             HloInstruction::CreateConstant(Literal::CreateR0<int32>(-1)))));
    208     body = module().AddEmbeddedComputation(body_builder.Build());
    209   }
    210 
    211   builder.AddInstruction(HloInstruction::CreateWhile(
    212       loop_init->shape(), condition, body, loop_init));
    213 
    214   module().AddEntryComputation(builder.Build());
    215   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    216 }
    217 
    218 // Construct a loop where we swap the tuple elements in each iteration.
    219 // Although the tuple elements aren't used in the loop, we don't eliminate them,
    220 // because the swapping side-effect is visible to users of the loop.
    221 TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) {
    222   HloComputation::Builder builder(TestName());
    223   auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
    224       builder.AddInstruction(
    225           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    226       builder.AddInstruction(
    227           HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
    228   }));
    229 
    230   HloComputation* condition =
    231       MakeAlwaysTrueComputation(loop_init->shape(), &module());
    232   HloComputation* body;
    233   {
    234     HloComputation::Builder body_builder(TestName() + ".body");
    235     auto param = body_builder.AddInstruction(
    236         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
    237     auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    238     body_builder.AddInstruction(HloInstruction::CreateTuple({
    239         body_builder.AddInstruction(
    240             HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)),
    241         body_builder.AddInstruction(
    242             HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)),
    243     }));
    244     body = module().AddEmbeddedComputation(body_builder.Build());
    245   }
    246 
    247   builder.AddInstruction(HloInstruction::CreateWhile(
    248       loop_init->shape(), condition, body, loop_init));
    249 
    250   module().AddEntryComputation(builder.Build());
    251   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    252 }
    253 
    254 // Construct a loop where we assign a constant to tuple element 0 in each
    255 // iteration.  We can't eliminate tuple element 0, even though we never use its
    256 // value.
    257 TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) {
    258   HloComputation::Builder builder(TestName());
    259   auto loop_init = builder.AddInstruction(
    260       HloInstruction::CreateTuple({builder.AddInstruction(
    261           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)))}));
    262 
    263   HloComputation* condition =
    264       MakeAlwaysTrueComputation(loop_init->shape(), &module());
    265   HloComputation* body;
    266   {
    267     HloComputation::Builder body_builder(TestName() + ".body");
    268     body_builder.AddInstruction(
    269         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
    270     body_builder.AddInstruction(HloInstruction::CreateTuple({
    271         body_builder.AddInstruction(
    272             HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
    273     }));
    274     body = module().AddEmbeddedComputation(body_builder.Build());
    275   }
    276 
    277   builder.AddInstruction(HloInstruction::CreateWhile(
    278       loop_init->shape(), condition, body, loop_init));
    279 
    280   module().AddEntryComputation(builder.Build());
    281   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    282 }
    283 
    284 // Nothing to simplify in a while loop whose tuple has 0 elements.
    285 TEST_F(WhileLoopSimplifierTest, EmptyTuple) {
    286   HloComputation::Builder builder(TestName());
    287   auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({}));
    288 
    289   HloComputation* condition =
    290       MakeAlwaysTrueComputation(loop_init->shape(), &module());
    291   HloComputation* body;
    292   {
    293     HloComputation::Builder body_builder(TestName() + ".body");
    294     body_builder.AddInstruction(
    295         HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
    296     body_builder.AddInstruction(HloInstruction::CreateTuple({}));
    297     body = module().AddEmbeddedComputation(body_builder.Build());
    298   }
    299 
    300   builder.AddInstruction(HloInstruction::CreateWhile(
    301       loop_init->shape(), condition, body, loop_init));
    302   module().AddEntryComputation(builder.Build());
    303   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    304 }
    305 
    306 // While loop where one tuple element is used twice in the body, and thus can't
    307 // be simplified away.
    308 TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) {
    309   HloComputation::Builder builder(TestName());
    310   auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
    311       builder.AddInstruction(
    312           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    313       builder.AddInstruction(
    314           HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
    315   }));
    316 
    317   HloComputation* condition =
    318       MakeAlwaysTrueComputation(loop_init->shape(), &module());
    319 
    320   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    321   HloComputation* body;
    322   {
    323     HloComputation::Builder body_builder(TestName() + ".body");
    324     auto* param = body_builder.AddInstruction(
    325         HloInstruction::CreateParameter(0, loop_init->shape(), "param0"));
    326     auto* gte0 = body_builder.AddInstruction(
    327         HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0));
    328     // get0 is used twice in the loop body's tuple.
    329     body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0}));
    330     body = module().AddEmbeddedComputation(body_builder.Build());
    331   }
    332 
    333   builder.AddInstruction(HloInstruction::CreateWhile(
    334       loop_init->shape(), condition, body, loop_init));
    335   module().AddEntryComputation(builder.Build());
    336   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    337 }
    338 
    339 // This while loop has three tuple elements.  Element 0 is unused and should be
    340 // removed. Element 1 is used by the loop body, and element 2 is used by the
    341 // loop condition; these two should stay.
    342 TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) {
    343   HloComputation::Builder builder(TestName());
    344   auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
    345       builder.AddInstruction(
    346           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    347       builder.AddInstruction(
    348           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    349       builder.AddInstruction(
    350           HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    351   }));
    352   auto loop_shape = loop_init->shape();
    353   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    354 
    355   HloComputation* condition;
    356   {
    357     HloComputation::Builder cond_builder(TestName() + ".loop_condition");
    358     auto param = cond_builder.AddInstruction(
    359         HloInstruction::CreateParameter(0, loop_shape, "param0"));
    360     cond_builder.AddInstruction(HloInstruction::CreateBinary(
    361         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq,
    362         cond_builder.AddInstruction(
    363             HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
    364         cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    365             scalar_s32, param, /*index=*/2))));
    366     condition = module().AddEmbeddedComputation(cond_builder.Build());
    367   }
    368 
    369   HloComputation* body;
    370   {
    371     HloComputation::Builder body_builder(TestName() + ".body");
    372     auto* param = body_builder.AddInstruction(
    373         HloInstruction::CreateParameter(0, loop_shape, "loop_var"));
    374 
    375     auto* tuple0 = body_builder.AddInstruction(
    376         HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0));
    377     auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary(
    378         scalar_s32, HloOpcode::kAdd,
    379         body_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    380             scalar_s32, param, /*index=*/1)),
    381         body_builder.AddInstruction(
    382             HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)))));
    383     auto* tuple2 = body_builder.AddInstruction(
    384         HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2));
    385     body_builder.AddInstruction(
    386         HloInstruction::CreateTuple({tuple0, tuple1, tuple2}));
    387 
    388     body = module().AddEmbeddedComputation(body_builder.Build());
    389   }
    390 
    391   auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
    392       loop_init->shape(), condition, body, loop_init));
    393 
    394   module().AddEntryComputation(builder.Build());
    395   EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
    396 
    397   // We leave most of the checking to HloVerifiedTestBase, which runs the
    398   // verifier on module() at the end of this test.
    399   HloInstruction* new_while_op = *std::find_if(
    400       module().entry_computation()->instructions().begin(),
    401       module().entry_computation()->instructions().end(),
    402       [&](const HloInstruction* instr) {
    403         return instr != while_op && instr->opcode() == HloOpcode::kWhile;
    404       });
    405   EXPECT_TRUE(
    406       ShapeUtil::Equal(new_while_op->shape(),
    407                        ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32})))
    408       << ShapeUtil::HumanString(new_while_op->shape());
    409   EXPECT_THAT(
    410       new_while_op->while_body()->root_instruction(),
    411       op::Tuple(
    412           op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0),
    413                   op::Constant()),
    414           op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
    415 
    416   EXPECT_THAT(new_while_op->while_condition()->root_instruction(),
    417               op::Eq(op::Constant(),
    418                      op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
    419 }
    420 
    421 TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) {
    422   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    423   Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
    424 
    425   HloComputation* while_body = [&]() {
    426     HloComputation::Builder builder(TestName() + ".passthrough");
    427     HloInstruction* param = builder.AddInstruction(
    428         HloInstruction::CreateParameter(0, while_shape, "param"));
    429     HloComputation* result = module().AddEmbeddedComputation(builder.Build());
    430 
    431     result->AddInstruction(
    432         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    433     return result;
    434   }();
    435 
    436   HloComputation::Builder builder(TestName());
    437   auto* init_value = builder.AddInstruction(
    438       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    439   builder.AddInstruction(HloInstruction::CreateWhile(
    440       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    441       while_body, init_value));
    442   module().AddEntryComputation(builder.Build());
    443   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    444                           WhileLoopSimplifier{}.Run(&module()));
    445   EXPECT_FALSE(simplified_loop);
    446 }
    447 
    448 }  // namespace
    449 }  // namespace xla
    450