/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using ::testing::Eq;
using ::testing::HasSubstr;

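// Test fixture that builds a simple counted while loop, with an S32 induction
// variable and an F32[8] data array in the loop state, and checks whether
// gpu::CanTransformWhileToFor recognizes it as a for loop.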
class WhileTransformerTest : public HloTestBase {
 protected:
  WhileTransformerTest()
      : module_(CreateNewModule()),
        induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
        data_shape_(ShapeUtil::MakeShape(F32, {8})),
        condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}

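  // Builds the loop condition: GTE(loop_state, tuple_index) < limit.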
  std::unique_ptr<HloComputation> BuildConditionComputation(
      const int64 tuple_index, const int64 limit) {
    auto builder = HloComputation::Builder(TestName() + ".Condition");
    auto limit_const = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(tuple_index), "loop_state"));
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
            limit_const->shape(), loop_state, tuple_index));
    builder.AddInstruction(
        HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
                                     induction_variable, limit_const));
    return builder.Build();
  }

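  // Builds the loop body: advances the induction variable by 'increment' and
  // updates the data array by adding the broadcast induction variable to it.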
  std::unique_ptr<HloComputation> BuildBodyComputation(
      const int64 ind_var_tuple_index, const int64 data_tuple_index,
      const int64 increment) {
    auto builder = HloComputation::Builder(TestName() + ".Body");
    // Create param instruction to access loop state.
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
    // Update the induction variable GTE(ind_var_tuple_index).
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
            induction_variable_shape_, loop_state, ind_var_tuple_index));
    auto inc = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(increment)));
    auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
        induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
    // Update data GTE(data_tuple_index).
    auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
        data_shape_, loop_state, data_tuple_index));
    // Use 'induction_variable' in the data update as well: broadcast it to the
    // data shape and add it to 'data'.
    auto update = builder.AddInstruction(
        HloInstruction::CreateBroadcast(data_shape_, induction_variable, {}));
    auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape_, HloOpcode::kAdd, data, update));
    // Create the output tuple, with the induction variable at
    // 'ind_var_tuple_index'.
    if (ind_var_tuple_index == 0) {
      builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
    } else {
      builder.AddInstruction(HloInstruction::CreateTuple({add1, add0}));
    }
    return builder.Build();
  }

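  // Builds the entry computation: initializes the loop state tuple with
  // 'ind_var_init' and a zero-filled data array, then adds the kWhile
  // instruction and returns it.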
  HloInstruction* BuildWhileInstruction(HloComputation* condition,
                                        HloComputation* body,
                                        const int64 ind_var_tuple_index,
                                        const int64 ind_var_init) {
    auto builder = HloComputation::Builder(TestName() + ".While");
    auto induction_var_init = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(ind_var_init)));
    auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
        Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
    auto loop_state_init =
        ind_var_tuple_index == 0
            ? builder.AddInstruction(
                  HloInstruction::CreateTuple({induction_var_init, data_init}))
            : builder.AddInstruction(
                  HloInstruction::CreateTuple({data_init, induction_var_init}));
    auto while_hlo = builder.AddInstruction(
        HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
                                    condition, body, loop_state_init));
    module_->AddEntryComputation(builder.Build());
    return while_hlo;
  }

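  // Runs GPU instruction fusion twice: first without, then with instruction
  // duplication.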
  void RunFusionPasses() {
    // Run standard fusion passes.
    EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false)
                    .Run(module_.get())
                    .ValueOrDie());
    EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true)
                    .Run(module_.get())
                    .ValueOrDie());
  }

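  // Verifies that the module is well-formed, then runs the copy-insertion
  // pass, which materializes the copies around while-loop state that the
  // transformer sees in a real pipeline.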
  void RunCopyInsertionPass() {
    HloVerifier verifier;
    TF_ASSERT_OK(verifier.Run(module_.get()).status());
    CopyInsertion copy_insertion;
    TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
  }

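  // Returns the loop state shape: a (S32, F32[8]) tuple with the induction
  // variable at 'ind_var_tuple_index' and the data array at the other index.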
  Shape GetLoopStateShape(const int64 ind_var_tuple_index) {
    if (ind_var_tuple_index == 0) {
      return ShapeUtil::MakeTupleShape(
          {induction_variable_shape_, data_shape_});
    } else {
      return ShapeUtil::MakeTupleShape(
          {data_shape_, induction_variable_shape_});
    }
  }

  std::unique_ptr<HloModule> module_;
  Shape induction_variable_shape_;
  Shape data_shape_;
  Shape condition_result_shape_;
};

// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches exact opcode expressions. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
  // Build computation with induction variable at tuple element 0.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
  // Run HLO optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  TF_ASSERT_OK(result.status());
  // Check the returned (start, limit, increment) tuple.
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches exact opcode expressions. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
  // Build computation with induction variable at tuple element 1.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 1, 0);
  // Run HLO optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  TF_ASSERT_OK(result.status());
  // Check the returned (start, limit, increment) tuple.
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches exact opcode expressions. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
  // Build computation with an invalid loop limit (start 10 >= limit 5).
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 10);
  // Run HLO optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Loop start must be less than loop limit."));
}

// TODO(b/68830972): The while transformer is far too fragile. It
// pattern-matches exact opcode expressions. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
  // Build computation with an invalid (negative) loop increment.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
  // Run HLO optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_FALSE(result.ok());
  // Note: the expected substring matches the implementation's error text
  // verbatim, including its grammar.
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Loop increment must greater than zero."));
}

}  // namespace
}  // namespace xla