Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/test.h"
     28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/core/lib/core/status_test_util.h"
     31 
     32 namespace op = xla::testing::opcode_matchers;
     33 
     34 namespace xla {
     35 namespace {
     36 
     37 class TupleSimplifierTest : public HloTestBase {
     38  protected:
     39   void Run(HloModule* module, bool change_expected) {
     40     TupleSimplifier simplifier;
     41     auto changed_status = simplifier.Run(module);
     42     TF_ASSERT_OK(changed_status.status());
     43     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
     44   }
     45 
     46   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
     47   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
     48       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
     49        ShapeUtil::MakeShape(F32, {})});
     50 };
     51 
     52 TEST_F(TupleSimplifierTest, TupleOfParameters) {
     53   // A Tuple constructed of a bunch of parameters should not be changed.
     54   HloComputation::Builder builder(TestName());
     55   HloInstruction* param0 = builder.AddInstruction(
     56       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
     57   HloInstruction* param1 = builder.AddInstruction(
     58       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
     59   HloInstruction* param2 = builder.AddInstruction(
     60       HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
     61   builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2}));
     62   auto module = CreateNewModule();
     63   module->AddEntryComputation(builder.Build());
     64 
     65   Run(module.get(), /*change_expected=*/false);
     66 }
     67 
     68 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
     69   // A GTE of a tuple parameter should not be changed.
     70   HloComputation::Builder builder(TestName());
     71   HloInstruction* param = builder.AddInstruction(
     72       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
     73   builder.AddInstruction(
     74       HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
     75   auto module = CreateNewModule();
     76   module->AddEntryComputation(builder.Build());
     77 
     78   Run(module.get(), /*change_expected=*/false);
     79 }
     80 
     81 TEST_F(TupleSimplifierTest, GteOfTuple) {
     82   // A GTE of a Tuple should be short-circuited.
     83   HloComputation::Builder builder(TestName());
     84   HloInstruction* param0 = builder.AddInstruction(
     85       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
     86   HloInstruction* param1 = builder.AddInstruction(
     87       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
     88   HloInstruction* param2 = builder.AddInstruction(
     89       HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
     90   HloInstruction* tuple = builder.AddInstruction(
     91       HloInstruction::CreateTuple({param0, param1, param2}));
     92   HloInstruction* gte = builder.AddInstruction(
     93       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
     94 
     95   auto module = CreateNewModule();
     96   auto computation = module->AddEntryComputation(builder.Build());
     97 
     98   EXPECT_THAT(computation->root_instruction(), gte);
     99 
    100   Run(module.get(), /*change_expected=*/true);
    101 
    102   EXPECT_THAT(computation->root_instruction(), param1);
    103 }
    104 
    105 TEST_F(TupleSimplifierTest, GteOfTupleChain) {
    106   // Verify a chain of GTE/Tuple instructions is collapsed.
    107   HloComputation::Builder builder(TestName());
    108   HloInstruction* param = builder.AddInstruction(
    109       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    110 
    111   const int kChainLength = 10;
    112   HloInstruction* element = param;
    113   for (int i = 0; i < kChainLength; ++i) {
    114     HloInstruction* tuple = builder.AddInstruction(
    115         HloInstruction::CreateTuple({element, element, element}));
    116     element = builder.AddInstruction(
    117         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
    118   }
    119   builder.AddInstruction(
    120       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element));
    121 
    122   auto module = CreateNewModule();
    123   auto computation = module->AddEntryComputation(builder.Build());
    124 
    125   EXPECT_THAT(computation->root_instruction(),
    126               op::Negate(op::GetTupleElement(op::Tuple())));
    127 
    128   Run(module.get(), /*change_expected=*/true);
    129 
    130   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
    131 }
    132 
    133 TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
    134   // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested
    135   // to some depth with a chain of Tuple instructions, then extracted with a
    136   // chain of GTE instructions.
    137   HloComputation::Builder builder(TestName());
    138   HloInstruction* param = builder.AddInstruction(
    139       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
    140 
    141   const int kNestingDepth = 5;
    142   HloInstruction* nested_tuple = param;
    143   for (int i = 0; i < kNestingDepth; ++i) {
    144     nested_tuple = builder.AddInstruction(
    145         HloInstruction::CreateTuple({nested_tuple, nested_tuple}));
    146   }
    147 
    148   HloInstruction* element = nested_tuple;
    149   for (int i = 0; i < kNestingDepth; ++i) {
    150     element = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    151         ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0));
    152   }
    153 
    154   auto module = CreateNewModule();
    155   auto computation = module->AddEntryComputation(builder.Build());
    156 
    157   EXPECT_THAT(computation->root_instruction(), element);
    158 
    159   Run(module.get(), /*change_expected=*/true);
    160 
    161   EXPECT_THAT(computation->root_instruction(), param);
    162 }
    163 
    164 TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
    165   // Verify that a tuple constructed of GTE instructions operating on the same
    166   // tuple are collapsed.
    167   HloComputation::Builder builder(TestName());
    168   HloInstruction* tuple_param = builder.AddInstruction(
    169       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
    170   HloInstruction* gte0 = builder.AddInstruction(
    171       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
    172   HloInstruction* gte1 = builder.AddInstruction(
    173       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
    174   HloInstruction* gte2 = builder.AddInstruction(
    175       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2));
    176   HloInstruction* tuple =
    177       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
    178 
    179   auto module = CreateNewModule();
    180   auto computation = module->AddEntryComputation(builder.Build());
    181 
    182   EXPECT_THAT(computation->root_instruction(), tuple);
    183 
    184   Run(module.get(), /*change_expected=*/true);
    185 
    186   EXPECT_THAT(computation->root_instruction(), tuple_param);
    187 }
    188 
    189 TEST_F(TupleSimplifierTest, IncompatibleTuples) {
    190   // Verify that a tuple->GTE->tuple construct is not simplified if the input
    191   // and output tuple are not compatible shapes.
    192   HloComputation::Builder builder(TestName());
    193   HloInstruction* tuple_param = builder.AddInstruction(
    194       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
    195   HloInstruction* gte0 = builder.AddInstruction(
    196       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
    197   HloInstruction* gte1 = builder.AddInstruction(
    198       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
    199   // Output tuple has only two elements. Parameter tuple has three elements so
    200   // simplification is not possible.
    201   HloInstruction* tuple =
    202       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
    203 
    204   auto module = CreateNewModule();
    205   auto computation = module->AddEntryComputation(builder.Build());
    206 
    207   EXPECT_THAT(computation->root_instruction(), tuple);
    208 
    209   Run(module.get(), /*change_expected=*/false);
    210 
    211   EXPECT_THAT(computation->root_instruction(), tuple);
    212 }
    213 
    214 }  // namespace
    215 }  // namespace xla
    216