Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/defuser.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     20 #include "tensorflow/compiler/xla/shape_util.h"
     21 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
     22 
     23 namespace op = xla::testing::opcode_matchers;
     24 
     25 namespace xla {
     26 namespace {
     27 
     28 class DefuserTest : public HloVerifiedTestBase {
     29  protected:
     30   // Returns the number of fusion instructions in the module.
     31   int FusionCount() {
     32     int count = 0;
     33     for (HloComputation* computation : module().computations()) {
     34       if (computation->IsFusionComputation()) {
     35         count++;
     36       }
     37     }
     38     return count;
     39   }
     40 
     41   Defuser defuser_;
     42   const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2});
     43 };
     44 
     45 TEST_F(DefuserTest, NoFusionInstruction) {
     46   auto builder = HloComputation::Builder(TestName());
     47   auto param0 =
     48       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
     49   auto param1 =
     50       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
     51   builder.AddInstruction(
     52       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
     53 
     54   module().AddEntryComputation(builder.Build());
     55   EXPECT_EQ(0, FusionCount());
     56 
     57   EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie());
     58 }
     59 
     60 TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) {
     61   auto builder = HloComputation::Builder(TestName());
     62   auto param0 =
     63       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
     64   auto param1 =
     65       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
     66   auto add = builder.AddInstruction(
     67       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
     68 
     69   auto computation = module().AddEntryComputation(builder.Build());
     70   computation->CreateFusionInstruction({add},
     71                                        HloInstruction::FusionKind::kLoop);
     72 
     73   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     74 
     75   EXPECT_EQ(1, FusionCount());
     76   EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
     77   EXPECT_EQ(0, FusionCount());
     78 
     79   EXPECT_THAT(computation->root_instruction(),
     80               op::Add(op::Parameter(), op::Parameter()));
     81 }
     82 
     83 TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) {
     84   auto builder = HloComputation::Builder(TestName());
     85   auto param0 =
     86       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
     87   auto param1 =
     88       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
     89   auto add = builder.AddInstruction(
     90       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
     91   builder.AddInstruction(
     92       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
     93 
     94   auto computation = module().AddEntryComputation(builder.Build());
     95   computation->CreateFusionInstruction({add},
     96                                        HloInstruction::FusionKind::kLoop);
     97 
     98   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion()));
     99 
    100   EXPECT_EQ(1, FusionCount());
    101   EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
    102   EXPECT_EQ(0, FusionCount());
    103 
    104   EXPECT_THAT(computation->root_instruction(),
    105               op::Negate(op::Add(op::Parameter(), op::Parameter())));
    106 }
    107 
    108 TEST_F(DefuserTest, NonTrivialFusionInstruction) {
    109   auto builder = HloComputation::Builder(TestName());
    110   auto param0 =
    111       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
    112   auto param1 =
    113       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
    114   auto param3 =
    115       builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
    116   auto add = builder.AddInstruction(
    117       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
    118   auto negate = builder.AddInstruction(
    119       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
    120   auto sub = builder.AddInstruction(
    121       HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
    122   auto mul = builder.AddInstruction(
    123       HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
    124   auto div = builder.AddInstruction(
    125       HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
    126   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
    127       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
    128   auto add2 = builder.AddInstruction(
    129       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
    130 
    131   auto computation = module().AddEntryComputation(builder.Build());
    132   computation->CreateFusionInstruction(
    133       {add2, constant, div, mul, sub, negate, add},
    134       HloInstruction::FusionKind::kLoop);
    135 
    136   EXPECT_THAT(computation->root_instruction(), op::Fusion());
    137 
    138   EXPECT_EQ(1, FusionCount());
    139   EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
    140   EXPECT_EQ(0, FusionCount());
    141 
    142   EXPECT_THAT(computation->root_instruction(),
    143               op::Add(op::Constant(), op::Divide()));
    144 }
    145 
    146 TEST_F(DefuserTest, MultipleFusionInstructions) {
    147   auto builder = HloComputation::Builder(TestName());
    148   auto param0 =
    149       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
    150   auto param1 =
    151       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
    152   auto param3 =
    153       builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
    154   auto add = builder.AddInstruction(
    155       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
    156   auto negate = builder.AddInstruction(
    157       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
    158   auto sub = builder.AddInstruction(
    159       HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
    160   auto mul = builder.AddInstruction(
    161       HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
    162   auto div = builder.AddInstruction(
    163       HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
    164   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
    165       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
    166   auto add2 = builder.AddInstruction(
    167       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
    168 
    169   auto computation = module().AddEntryComputation(builder.Build());
    170   computation->CreateFusionInstruction({add2, constant, div, mul},
    171                                        HloInstruction::FusionKind::kLoop);
    172   computation->CreateFusionInstruction({sub, negate, add},
    173                                        HloInstruction::FusionKind::kLoop);
    174 
    175   EXPECT_THAT(computation->root_instruction(), op::Fusion());
    176 
    177   EXPECT_EQ(2, FusionCount());
    178   EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
    179   EXPECT_EQ(0, FusionCount());
    180 
    181   EXPECT_THAT(computation->root_instruction(),
    182               op::Add(op::Constant(), op::Divide()));
    183 }
    184 
    185 TEST_F(DefuserTest, NestedFusionInstructions) {
    186   auto builder = HloComputation::Builder(TestName());
    187   auto param0 =
    188       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
    189   auto param1 =
    190       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
    191   auto add = builder.AddInstruction(
    192       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
    193   auto negate = builder.AddInstruction(
    194       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
    195 
    196   auto computation = module().AddEntryComputation(builder.Build());
    197   auto outer_fusion = computation->CreateFusionInstruction(
    198       {negate, add}, HloInstruction::FusionKind::kLoop);
    199   HloInstruction* fused_negate = outer_fusion->fused_expression_root();
    200   ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate);
    201   outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
    202       {fused_negate}, HloInstruction::FusionKind::kLoop);
    203 
    204   EXPECT_THAT(computation->root_instruction(), op::Fusion());
    205 
    206   EXPECT_EQ(2, FusionCount());
    207   EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
    208   EXPECT_EQ(0, FusionCount());
    209 
    210   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add()));
    211 }
    212 
    213 }  // namespace
    214 }  // namespace xla
    215