Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     19 #include "tensorflow/compiler/xla/test.h"
     20 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
     21 #include "tensorflow/core/lib/core/status_test_util.h"
     22 
     23 namespace xla {
     24 namespace {
     25 
     26 namespace op = xla::testing::opcode_matchers;
     27 
     28 class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
     29  public:
     30   // Makes a computation which has one parameter, of the given shape, and always
     31   // returns PRED[]{true}.  This is useful as a dummy loop condition.
     32   HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
     33                                             HloModule* module);
     34 };
     35 
     36 static void FindOnlyWhileInstruction(HloComputation* computation,
     37                                      HloInstruction** while_instruction) {
     38   *while_instruction = nullptr;
     39   for (auto* instr : computation->instructions()) {
     40     if (instr->opcode() == HloOpcode::kWhile) {
     41       ASSERT_EQ(*while_instruction, nullptr);
     42       *while_instruction = instr;
     43     }
     44   }
     45 
     46   ASSERT_NE(*while_instruction, nullptr);
     47 }
     48 
     49 HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
     50     const Shape& param_shape, HloModule* module) {
     51   HloComputation::Builder builder(TestName() + ".always_true");
     52   builder.AddInstruction(
     53       HloInstruction::CreateParameter(0, param_shape, "param"));
     54   builder.AddInstruction(
     55       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
     56   return module->AddEmbeddedComputation(builder.Build());
     57 }
     58 
     59 TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
     60   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
     61   Shape while_shape =
     62       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
     63 
     64   HloComputation* while_body = [&]() {
     65     HloComputation::Builder builder(TestName() + ".while_body");
     66     HloInstruction* param = builder.AddInstruction(
     67         HloInstruction::CreateParameter(0, while_shape, "param"));
     68     HloInstruction* gte_0 = builder.AddInstruction(
     69         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
     70     HloInstruction* gte_1 = builder.AddInstruction(
     71         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
     72     HloInstruction* add_result =
     73         builder.AddInstruction(HloInstruction::CreateBinary(
     74             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
     75     builder.AddInstruction(
     76         HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
     77 
     78     return module().AddEmbeddedComputation(builder.Build());
     79   }();
     80 
     81   HloComputation::Builder builder(TestName());
     82   auto* init_value = builder.AddInstruction(
     83       HloInstruction::CreateParameter(0, while_shape, "init_value"));
     84   builder.AddInstruction(HloInstruction::CreateWhile(
     85       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
     86       while_body, init_value));
     87   HloComputation* entry_computation =
     88       module().AddEntryComputation(builder.Build());
     89   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
     90                           WhileLoopInvariantCodeMotion{}.Run(&module()));
     91   EXPECT_TRUE(simplified_loop);
     92 
     93   HloInstruction* transformed_while;
     94   FindOnlyWhileInstruction(entry_computation, &transformed_while);
     95 
     96   EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
     97   EXPECT_THAT(transformed_while->while_body()->instructions(),
     98               Each(Not(op::Add())));
     99 }
    100 
    101 TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
    102   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    103   Shape while_shape =
    104       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
    105 
    106   HloComputation* while_body = [&]() {
    107     HloComputation::Builder builder(TestName() + ".while_body");
    108     HloInstruction* param = builder.AddInstruction(
    109         HloInstruction::CreateParameter(0, while_shape, "param"));
    110     HloInstruction* gte_0 = builder.AddInstruction(
    111         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    112     HloInstruction* gte_1 = builder.AddInstruction(
    113         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    114     HloInstruction* gte_2_loop_variant = builder.AddInstruction(
    115         HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));
    116 
    117     HloInstruction* add_result =
    118         builder.AddInstruction(HloInstruction::CreateBinary(
    119             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    120     HloInstruction* mul_result =
    121         builder.AddInstruction(HloInstruction::CreateBinary(
    122             scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
    123     HloInstruction* negate_result =
    124         builder.AddInstruction(HloInstruction::CreateUnary(
    125             scalar_s32, HloOpcode::kNegate, mul_result));
    126     HloInstruction* constant = builder.AddInstruction(
    127         HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
    128     HloInstruction* sub_result =
    129         builder.AddInstruction(HloInstruction::CreateBinary(
    130             scalar_s32, HloOpcode::kSubtract, negate_result, constant));
    131     HloInstruction* divide_result =
    132         builder.AddInstruction(HloInstruction::CreateBinary(
    133             scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
    134     builder.AddInstruction(
    135         HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));
    136 
    137     return module().AddEmbeddedComputation(builder.Build());
    138   }();
    139 
    140   HloComputation::Builder builder(TestName());
    141   auto* init_value = builder.AddInstruction(
    142       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    143   builder.AddInstruction(HloInstruction::CreateWhile(
    144       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    145       while_body, init_value));
    146   HloComputation* entry_computation =
    147       module().AddEntryComputation(builder.Build());
    148   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    149                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    150   EXPECT_TRUE(simplified_loop);
    151 
    152   HloInstruction* transformed_while;
    153   FindOnlyWhileInstruction(entry_computation, &transformed_while);
    154 
    155   EXPECT_THAT(entry_computation->instructions(),
    156               AllOf(Contains(op::Add()), Contains(op::Multiply()),
    157                     Contains(op::Negate()), Contains(op::Subtract()),
    158                     Contains(op::Constant()),
    159 
    160                     // The division had a loop varying operand so that better
    161                     // not be hoisted.
    162                     Not(Contains(op::Divide()))));
    163 
    164   EXPECT_THAT(transformed_while->while_body()->instructions(),
    165               Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
    166                              op::Subtract(), op::Constant()))));
    167 
    168   EXPECT_THAT(transformed_while->while_body()->instructions(),
    169               Contains(op::Divide()));
    170 }
    171 
    172 TEST_F(WhileLoopInvariantCodeMotionTest,
    173        DontHoistTriviallyLoopVaryingComputation) {
    174   // Basic negative test: the add expression is not loop invariant.
    175   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    176   Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
    177 
    178   HloComputation* while_body = [&]() {
    179     HloComputation::Builder builder(TestName() + ".while_body");
    180     HloInstruction* param = builder.AddInstruction(
    181         HloInstruction::CreateParameter(0, while_shape, "param"));
    182     HloInstruction* gte_0 = builder.AddInstruction(
    183         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    184     HloInstruction* gte_1 = builder.AddInstruction(
    185         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    186     HloInstruction* add_result =
    187         builder.AddInstruction(HloInstruction::CreateBinary(
    188             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    189     builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));
    190 
    191     return module().AddEmbeddedComputation(builder.Build());
    192   }();
    193 
    194   HloComputation::Builder builder(TestName());
    195   auto* init_value = builder.AddInstruction(
    196       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    197   auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
    198       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    199       while_body, init_value));
    200 
    201   module().AddEntryComputation(builder.Build());
    202 
    203   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    204                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    205   EXPECT_FALSE(simplified_loop);
    206 
    207   EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
    208 }
    209 
    210 TEST_F(WhileLoopInvariantCodeMotionTest,
    211        DontHoistLoopVaryingComputationWithAlternatingTuples) {
    212   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    213   Shape while_shape =
    214       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
    215 
    216   HloComputation* while_body = [&]() {
    217     HloComputation::Builder builder(TestName() + ".while_body");
    218     HloInstruction* param = builder.AddInstruction(
    219         HloInstruction::CreateParameter(0, while_shape, "param"));
    220     HloInstruction* gte_0 = builder.AddInstruction(
    221         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    222     HloInstruction* gte_1 = builder.AddInstruction(
    223         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    224     HloInstruction* add_result =
    225         builder.AddInstruction(HloInstruction::CreateBinary(
    226             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    227     builder.AddInstruction(
    228         HloInstruction::CreateTuple({gte_1, gte_0, add_result}));
    229 
    230     return module().AddEmbeddedComputation(builder.Build());
    231   }();
    232 
    233   HloComputation::Builder builder(TestName());
    234   auto* init_value = builder.AddInstruction(
    235       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    236   auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
    237       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    238       while_body, init_value));
    239 
    240   module().AddEntryComputation(builder.Build());
    241   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    242                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    243   EXPECT_FALSE(simplified_loop);
    244 
    245   EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
    246 }
    247 
    248 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
    249   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    250   Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
    251 
    252   HloComputation* while_body = [&]() {
    253     HloComputation::Builder builder(TestName() + ".while_body");
    254     HloInstruction* param = builder.AddInstruction(
    255         HloInstruction::CreateParameter(0, while_shape, "param"));
    256     HloInstruction* gte_0 = builder.AddInstruction(
    257         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    258     HloInstruction* gte_1 = builder.AddInstruction(
    259         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    260     builder.AddInstruction(
    261         HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
    262     builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
    263 
    264     return module().AddEmbeddedComputation(builder.Build());
    265   }();
    266 
    267   HloComputation::Builder builder(TestName());
    268   auto* init_value = builder.AddInstruction(
    269       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    270   auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
    271       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    272       while_body, init_value));
    273 
    274   module().AddEntryComputation(builder.Build());
    275 
    276   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    277                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    278   EXPECT_FALSE(simplified_loop);
    279 
    280   EXPECT_THAT(while_inst->while_body()->instructions(),
    281               Contains(op::Outfeed()));
    282 }
    283 
    284 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
    285   // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
    286   // bitcast either.
    287   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    288   auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
    289   Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
    290 
    291   HloComputation* while_body = [&]() {
    292     HloComputation::Builder builder(TestName() + ".while_body");
    293     HloInstruction* param = builder.AddInstruction(
    294         HloInstruction::CreateParameter(0, while_shape, "param"));
    295     HloInstruction* gte_0 = builder.AddInstruction(
    296         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    297     HloInstruction* gte_1 = builder.AddInstruction(
    298         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    299     HloInstruction* bitcast_inst = builder.AddInstruction(
    300         HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
    301     builder.AddInstruction(
    302         HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
    303     builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
    304 
    305     return module().AddEmbeddedComputation(builder.Build());
    306   }();
    307 
    308   HloComputation::Builder builder(TestName());
    309   auto* init_value = builder.AddInstruction(
    310       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    311   auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
    312       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    313       while_body, init_value));
    314 
    315   module().AddEntryComputation(builder.Build());
    316 
    317   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    318                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    319   EXPECT_FALSE(simplified_loop);
    320 
    321   EXPECT_THAT(while_inst->while_body()->instructions(),
    322               Contains(op::Outfeed()));
    323   EXPECT_THAT(while_inst->while_body()->instructions(),
    324               Contains(op::Bitcast()));
    325 }
    326 
    327 TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
    328   // The bitcast's user can be hoisted, so hoist the bitcast too.
    329   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    330   auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
    331   Shape while_shape =
    332       ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32});
    333 
    334   HloComputation* while_body = [&]() {
    335     HloComputation::Builder builder(TestName() + ".while_body");
    336     HloInstruction* param = builder.AddInstruction(
    337         HloInstruction::CreateParameter(0, while_shape, "param"));
    338     HloInstruction* gte_0 = builder.AddInstruction(
    339         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    340     HloInstruction* gte_1 = builder.AddInstruction(
    341         HloInstruction::CreateGetTupleElement(scalar_f32, param, 1));
    342     HloInstruction* bitcast_inst = builder.AddInstruction(
    343         HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
    344     HloInstruction* add_inst =
    345         builder.AddInstruction(HloInstruction::CreateBinary(
    346             scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1));
    347     builder.AddInstruction(
    348         HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));
    349 
    350     return module().AddEmbeddedComputation(builder.Build());
    351   }();
    352 
    353   HloComputation::Builder builder(TestName());
    354   auto* init_value = builder.AddInstruction(
    355       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    356   builder.AddInstruction(HloInstruction::CreateWhile(
    357       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    358       while_body, init_value));
    359 
    360   HloComputation* entry_computation =
    361       module().AddEntryComputation(builder.Build());
    362 
    363   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    364                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    365   EXPECT_TRUE(simplified_loop);
    366 
    367   HloInstruction* transformed_while;
    368   FindOnlyWhileInstruction(entry_computation, &transformed_while);
    369 
    370   EXPECT_THAT(transformed_while->while_body()->instructions(),
    371               Each(Not(op::Add())));
    372   EXPECT_THAT(transformed_while->while_body()->instructions(),
    373               Each(Not(op::Bitcast())));
    374   EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
    375   EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
    376 }
    377 
    378 TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
    379   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    380   Shape while_shape =
    381       ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
    382 
    383   HloComputation* while_body;
    384   {
    385     HloComputation::Builder builder(TestName() + ".while_body");
    386     HloInstruction* param = builder.AddInstruction(
    387         HloInstruction::CreateParameter(0, while_shape, "param"));
    388     HloInstruction* gte_0 = builder.AddInstruction(
    389         HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    390     HloInstruction* gte_1 = builder.AddInstruction(
    391         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    392     HloInstruction* add_result =
    393         builder.AddInstruction(HloInstruction::CreateBinary(
    394             scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    395     TF_ASSERT_OK(param->AddControlDependencyTo(add_result));
    396     builder.AddInstruction(
    397         HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
    398 
    399     while_body = module().AddEmbeddedComputation(builder.Build());
    400   }
    401 
    402   HloComputation::Builder builder(TestName());
    403   auto* init_value = builder.AddInstruction(
    404       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    405   builder.AddInstruction(HloInstruction::CreateWhile(
    406       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    407       while_body, init_value));
    408   module().AddEntryComputation(builder.Build());
    409   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    410                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    411   EXPECT_FALSE(simplified_loop);
    412 }
    413 
    414 TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
    415   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
    416   Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
    417 
    418   HloComputation* while_body = [&]() {
    419     HloComputation::Builder builder(TestName() + ".passthrough");
    420     HloInstruction* param = builder.AddInstruction(
    421         HloInstruction::CreateParameter(0, while_shape, "param"));
    422     HloComputation* result = module().AddEmbeddedComputation(builder.Build());
    423 
    424     result->AddInstruction(
    425         HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    426     return result;
    427   }();
    428 
    429   HloComputation::Builder builder(TestName());
    430   auto* init_value = builder.AddInstruction(
    431       HloInstruction::CreateParameter(0, while_shape, "init_value"));
    432   builder.AddInstruction(HloInstruction::CreateWhile(
    433       while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
    434       while_body, init_value));
    435   module().AddEntryComputation(builder.Build());
    436   TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
    437                           WhileLoopInvariantCodeMotion{}.Run(&module()));
    438   EXPECT_FALSE(simplified_loop);
    439 }
    440 
    441 }  // namespace
    442 }  // namespace xla
    443