Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/service/hlo_scheduling.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     28 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 
     32 namespace xla {
     33 namespace {
     34 
     35 class HloOrderingTest : public HloTestBase {};
     36 
     37 TEST_F(HloOrderingTest, LastUseScheduledFirst) {
     38   // Tests scheduling of the following HLO code:
     39   //
     40   //   %ab = abs(%param)
     41   //   %exp = exp(%param)
     42   //   %add = add(%ab, %exp)
     43   //   %negate = negate(%exp)
     44   //   %sub = subtract(%add, %negate)
     45   //
     46   // %add should be scheduled before %negate because %add is the last (and only)
     47   // use of %ab. Scheduling %add first then frees up %ab's buffer.
     48   const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
     49   auto builder = HloComputation::Builder(TestName());
     50   auto param =
     51       builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
     52   auto ab = builder.AddInstruction(
     53       HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
     54   auto exp = builder.AddInstruction(
     55       HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));
     56 
     57   auto add = builder.AddInstruction(
     58       HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
     59   auto negate = builder.AddInstruction(
     60       HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
     61   auto sub = builder.AddInstruction(
     62       HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));
     63 
     64   auto module = CreateNewModule();
     65   module->AddEntryComputation(builder.Build());
     66 
     67   TF_ASSERT_OK_AND_ASSIGN(
     68       SequentialHloOrdering::HloModuleSequence sequence,
     69       CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) {
     70         return ShapeUtil::ByteSizeOf(buffer.shape());
     71       }));
     72   // Verify that all instructions are in the sequence.
     73   EXPECT_EQ(module->entry_computation()->instruction_count(),
     74             sequence.at(module->entry_computation()).size());
     75 
     76   // The first instruction should be the parameter and the last the root "sub".
     77   EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
     78   EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
     79 
     80   SequentialHloOrdering ordering(module.get(), sequence);
     81   EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
     82 }
     83 
     84 TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
     85   // Tests the ordering of instructions in different computations using the
     86   // following HLO code:
     87   //
     88   // Entry computation:
     89   //   %x = Call(A, {})
     90   //   %y = Call(B, {%x})
     91   //
     92   // Computation A:
     93   //   %a = Call(C, {})
     94   //
     95   // Computation B:
     96   //   %b = Call(C, {})
     97   //
     98   // Computation C:
     99   //   %c = Constant(42.0f)
    100   //
    101   // This results in a diamond-shaped callgraph.
    102   auto module = CreateNewModule();
    103   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
    104 
    105   auto builder_c = HloComputation::Builder("C");
    106   HloInstruction* c = builder_c.AddInstruction(
    107       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    108   HloComputation* computation_c =
    109       module->AddEmbeddedComputation(builder_c.Build());
    110 
    111   auto builder_b = HloComputation::Builder("B");
    112   builder_b.AddInstruction(
    113       HloInstruction::CreateParameter(0, scalar_shape, "param"));
    114   HloInstruction* b = builder_b.AddInstruction(
    115       HloInstruction::CreateCall(scalar_shape, {}, computation_c));
    116   HloComputation* computation_b =
    117       module->AddEmbeddedComputation(builder_b.Build());
    118 
    119   auto builder_a = HloComputation::Builder("A");
    120   HloInstruction* a = builder_a.AddInstruction(
    121       HloInstruction::CreateCall(scalar_shape, {}, computation_c));
    122   HloComputation* computation_a =
    123       module->AddEmbeddedComputation(builder_a.Build());
    124 
    125   auto builder = HloComputation::Builder(TestName());
    126   HloInstruction* x = builder.AddInstruction(
    127       HloInstruction::CreateCall(scalar_shape, {}, computation_a));
    128   HloInstruction* y = builder.AddInstruction(
    129       HloInstruction::CreateCall(scalar_shape, {x}, computation_b));
    130   module->AddEntryComputation(builder.Build());
    131 
    132   DependencyHloOrdering ordering(module.get());
    133   EXPECT_TRUE(ordering.ExecutesBefore(x, y));
    134   EXPECT_FALSE(ordering.ExecutesBefore(y, x));
    135 
    136   EXPECT_TRUE(ordering.ExecutesBefore(a, b));
    137   EXPECT_FALSE(ordering.ExecutesBefore(b, a));
    138 
    139   EXPECT_FALSE(ordering.ExecutesBefore(a, x));
    140   EXPECT_TRUE(ordering.ExecutesBefore(a, y));
    141   EXPECT_FALSE(ordering.ExecutesBefore(x, a));
    142   EXPECT_FALSE(ordering.ExecutesBefore(y, a));
    143 
    144   EXPECT_FALSE(ordering.ExecutesBefore(b, x));
    145   EXPECT_FALSE(ordering.ExecutesBefore(b, y));
    146   EXPECT_TRUE(ordering.ExecutesBefore(x, b));
    147   EXPECT_FALSE(ordering.ExecutesBefore(y, b));
    148 
    149   // Instruction 'c' is called from multiple callsites and should be unordered
    150   // relative to all other instructions in the module.
    151   EXPECT_FALSE(ordering.ExecutesBefore(c, a));
    152   EXPECT_FALSE(ordering.ExecutesBefore(c, b));
    153   EXPECT_FALSE(ordering.ExecutesBefore(c, x));
    154   EXPECT_FALSE(ordering.ExecutesBefore(c, y));
    155   EXPECT_FALSE(ordering.ExecutesBefore(a, c));
    156   EXPECT_FALSE(ordering.ExecutesBefore(b, c));
    157   EXPECT_FALSE(ordering.ExecutesBefore(x, c));
    158   EXPECT_FALSE(ordering.ExecutesBefore(y, c));
    159 }
    160 
    161 TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
    162   // Tests the ordering of instructions in the body and condition of a while
    163   // instruction. HLO code:
    164   //
    165   // body(F32[]) %param):
    166   //   %negate = Negate(%param)
    167   //
    168   // condition(F32[] %param):
    169   //   %convert = Convert<PRED>(%param)
    170   //
    171   // entry:
    172   //   %constant = Constant(1.0)
    173   //   return While(%constant, body, condition)
    174   //
    175   auto module = CreateNewModule();
    176   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
    177 
    178   auto body_builder = HloComputation::Builder("body");
    179   auto body_param = body_builder.AddInstruction(
    180       HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
    181   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
    182       scalar_shape, HloOpcode::kNegate, body_param));
    183   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
    184 
    185   auto cond_builder = HloComputation::Builder("condition");
    186   auto cond_param = cond_builder.AddInstruction(
    187       HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
    188   auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
    189       ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
    190   HloComputation* condition =
    191       module->AddEmbeddedComputation(cond_builder.Build());
    192 
    193   auto builder = HloComputation::Builder(TestName());
    194   auto constant = builder.AddInstruction(
    195       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    196   auto xla_while = builder.AddInstruction(
    197       HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
    198   module->AddEntryComputation(builder.Build());
    199 
    200   DependencyHloOrdering ordering(module.get());
    201   EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while));
    202   EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param));
    203   EXPECT_TRUE(ordering.ExecutesBefore(constant, convert));
    204   EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param));
    205   EXPECT_TRUE(ordering.ExecutesBefore(constant, negate));
    206 
    207   // The while should be unordered relative to the body and condition
    208   // instructions.
    209   EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param));
    210   EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param));
    211   EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while));
    212   EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while));
    213 
    214   // Condition instructions should be ordered before body instructions.
    215   EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param));
    216   EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param));
    217   EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate));
    218   EXPECT_TRUE(ordering.ExecutesBefore(convert, negate));
    219 
    220   EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
    221 }
    222 
    223 TEST_F(HloOrderingTest, ValuesInWhileComputations) {
    224   // Tests the ordering of values (defined by dataflow analysis) in the body and
    225   // condition of a while instruction. HLO code:
    226   //
    227   // body(F32[]) %param):
    228   //   %negate = Negate(%param)
    229   //
    230   // condition(F32[] %param):
    231   //   %convert = Convert<PRED>(%param)
    232   //
    233   // entry:
    234   //   %constant = Constant(1.0)
    235   //   %while = While(%constant, body, condition)
    236   //   %add = Add(%constant, %while)
    237   //
    238   auto module = CreateNewModule();
    239   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
    240 
    241   auto body_builder = HloComputation::Builder("body");
    242   auto body_param = body_builder.AddInstruction(
    243       HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
    244   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
    245       scalar_shape, HloOpcode::kNegate, body_param));
    246   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
    247 
    248   auto cond_builder = HloComputation::Builder("condition");
    249   auto cond_param = cond_builder.AddInstruction(
    250       HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
    251   auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
    252       ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
    253   HloComputation* condition =
    254       module->AddEmbeddedComputation(cond_builder.Build());
    255 
    256   auto builder = HloComputation::Builder(TestName());
    257   auto constant = builder.AddInstruction(
    258       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    259   auto xla_while = builder.AddInstruction(
    260       HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
    261   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    262       scalar_shape, HloOpcode::kAdd, constant, xla_while));
    263   module->AddEntryComputation(builder.Build());
    264 
    265   TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
    266                           HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
    267   DependencyHloOrdering ordering(module.get());
    268 
    269   // Init value is defined before the while, but live range is not before the
    270   // while because of the use of the init value in the add.
    271   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
    272                                        dataflow->GetValueDefinedAt(xla_while)));
    273   EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
    274       dataflow->GetValueDefinedAt(constant),
    275       dataflow->GetValueDefinedAt(xla_while), *dataflow));
    276   EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
    277                                     dataflow->GetValueDefinedAt(xla_while),
    278                                     *dataflow));
    279 
    280   // Any value defined in the body or condition is defined before the while, and
    281   // has a live range strictly before the while.
    282   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
    283                                        dataflow->GetValueDefinedAt(xla_while)));
    284   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
    285       dataflow->GetValueDefinedAt(negate),
    286       dataflow->GetValueDefinedAt(xla_while), *dataflow));
    287   EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
    288                                      dataflow->GetValueDefinedAt(xla_while),
    289                                      *dataflow));
    290 
    291   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
    292                                        dataflow->GetValueDefinedAt(xla_while)));
    293   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
    294       dataflow->GetValueDefinedAt(convert),
    295       dataflow->GetValueDefinedAt(xla_while), *dataflow));
    296   EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
    297                                      dataflow->GetValueDefinedAt(xla_while),
    298                                      *dataflow));
    299 
    300   // The live range of the while should be before the add.
    301   EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
    302                                        dataflow->GetValueDefinedAt(add)));
    303   ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
    304 
    305   const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
    306   EXPECT_EQ(while_use.instruction, add);
    307   EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
    308       while_use, dataflow->GetValueDefinedAt(add), *dataflow));
    309   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
    310       dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
    311       *dataflow));
    312 }
    313 
    314 // Regression test for HloOrdering::ToString() crashing when fed a computation
    315 // containing a fusion node.
    316 TEST_F(HloOrderingTest, ToStringDoesNotCrash) {
    317   const char* module_str = R"(
    318 HloModule test_module
    319 
    320 body.v8 {
    321   prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
    322   get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0
    323   constant.1 = s32[] constant(1)
    324   add = s32[] add(get-tuple-element.4, constant.1)
    325   get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3
    326   get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1
    327   get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2
    328   ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7)
    329 }
    330 
    331 condition.v4 {
    332   constant.2 = s32[] constant(2)
    333   prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
    334   get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
    335   ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
    336 }
    337 
    338 fused_computation {
    339   get-tuple-element.5.param_1 = f32[3]{0} parameter(1)
    340   get-tuple-element.6.param_2 = f32[3]{0} parameter(2)
    341   add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2)
    342   get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0)
    343   ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1)
    344 }
    345 
    346 ENTRY while.v11 {
    347   constant.5 = s32[] constant(0)
    348   constant.6 = f32[3]{0} constant({1, 1, 1})
    349   constant.7 = f32[3]{0} constant({2, 2, 2})
    350   constant.8 = f32[3]{0} constant({3, 3, 3})
    351   tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8)
    352   while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8
    353   get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3
    354   get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1
    355   get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2
    356   ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation
    357 })";
    358 
    359   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    360                           tools::Parse(module_str));
    361   DependencyHloOrdering ordering(module.get());
    362   ordering.ToString();  // Shouldn't crash.
    363 }
    364 
    365 }  // namespace
    366 }  // namespace xla
    367