Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/liveness_util.h"
     17 
     18 #include <memory>
     19 
     20 #include "tensorflow/compiler/xla/service/hlo_module.h"
     21 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     23 
     24 namespace xla {
     25 namespace {
     26 
     27 class PointsToAnalysisTestBase : public HloTestBase {
     28  protected:
     29   void BuildModule(std::unique_ptr<HloComputation> computation) {
     30     module_ = CreateNewModule();
     31     computation_ = module_->AddEntryComputation(std::move(computation));
     32   }
     33 
     34   void RunAnalysis() {
     35     CHECK_NOTNULL(module_.get());
     36     points_to_analysis_ =
     37         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
     38     dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
     39   }
     40 
     41   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
     42     BuildModule(std::move(computation));
     43     RunAnalysis();
     44   }
     45 
     46   std::unique_ptr<HloModule> module_;
     47   HloComputation* computation_ = nullptr;
     48   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
     49   std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
     50 };
     51 
     52 class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
     53 
     54 TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
     55   auto builder = HloComputation::Builder(TestName());
     56 
     57   Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
     58   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
     59       0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
     60   auto gte0 = builder.AddInstruction(
     61       HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
     62   auto gte1 = builder.AddInstruction(
     63       HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
     64   builder.AddInstruction(
     65       HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
     66 
     67   BuildModuleAndRunAnalysis(builder.Build());
     68 
     69   // GetTupleElement instructions only access the top-level buffer of their
     70   // operand.
     71   EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_));
     72   EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
     73   EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
     74   EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
     75 
     76   EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_));
     77   EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_));
     78   EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_));
     79   EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_));
     80 }
     81 
     82 TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
     83   auto builder = HloComputation::Builder(TestName());
     84 
     85   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
     86   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
     87       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
     88   auto gte0 = builder.AddInstruction(
     89       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
     90   auto gte1 = builder.AddInstruction(
     91       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
     92 
     93   // Create a DynamicUpdateSlice instruction of tuple element 1.
     94   auto starts = builder.AddInstruction(
     95       HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
     96   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
     97       Literal::CreateR1<float>({2.f, 2.f, 2.f})));
     98   auto dynamic_update_slice =
     99       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    100           data_shape, gte1, update, starts));
    101   builder.AddInstruction(
    102       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    103 
    104   BuildModule(builder.Build());
    105   auto fusion = computation_->CreateFusionInstruction(
    106       {dynamic_update_slice, starts, update, gte1},
    107       HloInstruction::FusionKind::kLoop);
    108   RunAnalysis();
    109 
    110   // The fusion instruction never uses tuple element 0, but does use element 1.
    111   EXPECT_TRUE(
    112       DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
    113   EXPECT_FALSE(
    114       DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
    115 
    116   EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_));
    117   EXPECT_FALSE(
    118       DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_));
    119 }
    120 
    121 class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
    122 
    123 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
    124   auto builder = HloComputation::Builder(TestName());
    125 
    126   Shape shape = ShapeUtil::MakeShape(F32, {8});
    127   auto param = builder.AddInstruction(
    128       HloInstruction::CreateParameter(0, shape, "param"));
    129   auto exp = builder.AddInstruction(
    130       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
    131   auto log = builder.AddInstruction(
    132       HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
    133 
    134   BuildModuleAndRunAnalysis(builder.Build());
    135 
    136   EXPECT_TRUE(
    137       CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
    138   EXPECT_TRUE(
    139       CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
    140 
    141   EXPECT_TRUE(
    142       CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
    143   EXPECT_TRUE(
    144       CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_));
    145 }
    146 
    147 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
    148   auto builder = HloComputation::Builder(TestName());
    149 
    150   Shape in_shape = ShapeUtil::MakeShape(F32, {8});
    151   Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
    152   auto param0 = builder.AddInstruction(
    153       HloInstruction::CreateParameter(0, in_shape, "param0"));
    154   auto param1 = builder.AddInstruction(
    155       HloInstruction::CreateParameter(1, in_shape, "param1"));
    156   auto result = builder.AddInstruction(
    157       HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
    158 
    159   BuildModuleAndRunAnalysis(builder.Build());
    160 
    161   EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
    162                                              *points_to_analysis_));
    163   EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
    164                                              *points_to_analysis_));
    165 
    166   EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
    167                                              *dataflow_analysis_));
    168   EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
    169                                              *dataflow_analysis_));
    170 }
    171 
    172 TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
    173   auto builder = HloComputation::Builder(TestName());
    174 
    175   Shape shape = ShapeUtil::MakeShape(F32, {8});
    176   auto param = builder.AddInstruction(
    177       HloInstruction::CreateParameter(0, shape, "param"));
    178   auto exp = builder.AddInstruction(
    179       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
    180   auto copy = builder.AddInstruction(
    181       HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));
    182 
    183   BuildModuleAndRunAnalysis(builder.Build());
    184 
    185   EXPECT_TRUE(
    186       CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
    187   EXPECT_TRUE(
    188       CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_));
    189 
    190   EXPECT_TRUE(
    191       CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
    192   EXPECT_TRUE(
    193       CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_));
    194 }
    195 
    196 TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
    197   auto builder = HloComputation::Builder(TestName());
    198 
    199   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    200   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
    201       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
    202   auto gte0 = builder.AddInstruction(
    203       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
    204   auto gte1 = builder.AddInstruction(
    205       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
    206 
    207   // Create a DynamicUpdateSlice instruction of tuple element 1.
    208   auto starts = builder.AddInstruction(
    209       HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
    210   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
    211       Literal::CreateR1<float>({2.f, 2.f, 2.f})));
    212   auto dynamic_update_slice =
    213       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    214           data_shape, gte1, update, starts));
    215   builder.AddInstruction(
    216       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    217 
    218   BuildModule(builder.Build());
    219   auto fusion = computation_->CreateFusionInstruction(
    220       {dynamic_update_slice, starts, update, gte1},
    221       HloInstruction::FusionKind::kLoop);
    222   RunAnalysis();
    223 
    224   // The fusion instruction can share with tuple element 1.
    225   EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
    226                                              *points_to_analysis_));
    227   EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
    228                                             *points_to_analysis_));
    229 
    230   EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
    231                                              *dataflow_analysis_));
    232   EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
    233                                             *dataflow_analysis_));
    234 }
    235 
    236 TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
    237   auto builder = HloComputation::Builder(TestName());
    238 
    239   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    240   Shape update_shape = ShapeUtil::MakeShape(F32, {4});
    241   Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
    242   auto data = builder.AddInstruction(
    243       HloInstruction::CreateParameter(0, data_shape, "data"));
    244   auto update = builder.AddInstruction(
    245       HloInstruction::CreateParameter(1, update_shape, "update"));
    246   auto starts = builder.AddInstruction(
    247       HloInstruction::CreateParameter(2, starts_shape, "starts"));
    248   auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    249       data_shape, data, update, starts));
    250 
    251   BuildModuleAndRunAnalysis(builder.Build());
    252 
    253   // The DynamicUpdateSlice instruction can share with the data operand, but not
    254   // with update or starts.
    255   EXPECT_TRUE(
    256       CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
    257   EXPECT_FALSE(
    258       CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
    259   EXPECT_FALSE(
    260       CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));
    261 
    262   EXPECT_TRUE(
    263       CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_));
    264   EXPECT_FALSE(
    265       CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_));
    266   EXPECT_FALSE(
    267       CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_));
    268 }
    269 
    270 TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
    271   auto builder = HloComputation::Builder(TestName());
    272   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
    273 
    274   auto a = builder.AddInstruction(HloInstruction::CreateConstant(
    275       Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
    276   auto b = builder.AddInstruction(HloInstruction::CreateConstant(
    277       Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
    278 
    279   DotDimensionNumbers dot_dnums;
    280   dot_dnums.add_lhs_contracting_dimensions(1);
    281   dot_dnums.add_rhs_contracting_dimensions(0);
    282   auto dot = builder.AddInstruction(
    283       HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
    284 
    285   auto one = builder.AddInstruction(
    286       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    287   auto add_operand = builder.AddInstruction(
    288       HloInstruction::CreateBroadcast(data_shape, one, {1}));
    289 
    290   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    291       data_shape, HloOpcode::kAdd, dot, add_operand));
    292 
    293   BuildModule(builder.Build());
    294   auto fusion = computation_->CreateFusionInstruction(
    295       {add, dot}, HloInstruction::FusionKind::kOutput);
    296   RunAnalysis();
    297 
    298   // Output fused dot add should be able to share buffer with 'add_operand'.
    299   EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
    300                                             *points_to_analysis_));
    301 
    302   EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
    303                                             *dataflow_analysis_));
    304 }
    305 
    306 TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
    307   auto builder = HloComputation::Builder(TestName());
    308   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
    309 
    310   auto a = builder.AddInstruction(HloInstruction::CreateConstant(
    311       Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
    312   auto b = builder.AddInstruction(HloInstruction::CreateConstant(
    313       Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
    314   auto b_t = builder.AddInstruction(
    315       HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
    316 
    317   DotDimensionNumbers dot_dnums;
    318   dot_dnums.add_lhs_contracting_dimensions(1);
    319   dot_dnums.add_rhs_contracting_dimensions(0);
    320   auto dot = builder.AddInstruction(
    321       HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
    322 
    323   auto one = builder.AddInstruction(
    324       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    325   auto add_operand = builder.AddInstruction(
    326       HloInstruction::CreateBroadcast(data_shape, one, {1}));
    327 
    328   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    329       data_shape, HloOpcode::kAdd, dot, add_operand));
    330 
    331   BuildModule(builder.Build());
    332 
    333   auto nested_fusion = computation_->CreateFusionInstruction(
    334       {dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
    335 
    336   auto fusion = computation_->CreateFusionInstruction(
    337       {add, nested_fusion}, HloInstruction::FusionKind::kOutput);
    338   RunAnalysis();
    339 
    340   // Output fused transpose-dot-add should be share buffer with 'add_operand'.
    341   EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
    342                                             *points_to_analysis_));
    343 
    344   EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
    345                                             *dataflow_analysis_));
    346 }
    347 
    348 TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
    349   auto builder = HloComputation::Builder(TestName());
    350   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
    351 
    352   auto one = builder.AddInstruction(
    353       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    354   auto operand = builder.AddInstruction(
    355       HloInstruction::CreateBroadcast(data_shape, one, {1}));
    356 
    357   auto reverse = builder.AddInstruction(
    358       HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
    359 
    360   auto two = builder.AddInstruction(HloInstruction::CreateConstant(
    361       Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
    362 
    363   auto add = builder.AddInstruction(
    364       HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
    365 
    366   BuildModule(builder.Build());
    367   auto fusion = computation_->CreateFusionInstruction(
    368       {add, two, reverse}, HloInstruction::FusionKind::kOutput);
    369   RunAnalysis();
    370 
    371   // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
    372   EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
    373                                              *points_to_analysis_));
    374 
    375   EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
    376                                              *dataflow_analysis_));
    377 }
    378 
    379 TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
    380   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    381 
    382   auto make_cond = [this, &data_shape]() {
    383     auto builder = HloComputation::Builder(TestName() + ".Cond");
    384     auto data = builder.AddInstruction(
    385         HloInstruction::CreateParameter(0, data_shape, "data"));
    386     builder.AddInstruction(HloInstruction::CreateBinary(
    387         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
    388     return builder.Build();
    389   };
    390 
    391   auto make_body = [this, &data_shape]() {
    392     auto builder = HloComputation::Builder(TestName() + ".Body");
    393     auto data = builder.AddInstruction(
    394         HloInstruction::CreateParameter(0, data_shape, "data"));
    395     builder.AddInstruction(
    396         HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
    397     return builder.Build();
    398   };
    399 
    400   module_ = CreateNewModule();
    401   HloComputation* cond_computation =
    402       module_->AddEmbeddedComputation(make_cond());
    403   HloComputation* body_computation =
    404       module_->AddEmbeddedComputation(make_body());
    405 
    406   auto builder = HloComputation::Builder(TestName());
    407   auto data = builder.AddInstruction(
    408       HloInstruction::CreateParameter(0, data_shape, "data"));
    409   auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
    410       data_shape, cond_computation, body_computation, data));
    411   computation_ = module_->AddEntryComputation(builder.Build());
    412 
    413   RunAnalysis();
    414 
    415   // The While instruction can share with the data operand.
    416   EXPECT_TRUE(
    417       CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));
    418 
    419   EXPECT_TRUE(
    420       CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
    421 }
    422 
    423 // Tests that Call can alias operand buffer if the only use of the operand
    424 // in the called computation is an elementwise instruction.
    425 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
    426   Shape shape = ShapeUtil::MakeShape(F32, {8});
    427   // Build sub-computation with fusion root.
    428   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
    429   auto sub_param = sub_builder.AddInstruction(
    430       HloInstruction::CreateParameter(0, shape, "sub_param"));
    431   auto one = sub_builder.AddInstruction(
    432       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    433   auto ones = sub_builder.AddInstruction(
    434       HloInstruction::CreateBroadcast(shape, one, {1}));
    435   auto add = sub_builder.AddInstruction(
    436       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
    437 
    438   module_ = CreateNewModule();
    439   auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
    440   sub_computation->CreateFusionInstruction({add, ones},
    441                                            HloInstruction::FusionKind::kLoop);
    442 
    443   // Build entry-computation with kCall which calls 'sub_computation'.
    444   auto builder = HloComputation::Builder(TestName());
    445 
    446   auto param = builder.AddInstruction(
    447       HloInstruction::CreateParameter(0, shape, "param"));
    448   auto reverse =
    449       builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
    450   auto call = builder.AddInstruction(
    451       HloInstruction::CreateCall(shape, {reverse}, sub_computation));
    452   computation_ = module_->AddEntryComputation(builder.Build());
    453 
    454   RunAnalysis();
    455 
    456   EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
    457                                             *points_to_analysis_));
    458   EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {},
    459                                             *dataflow_analysis_));
    460 }
    461 
    462 }  // namespace
    463 }  // namespace xla
    464