Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     19 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     21 #include "tensorflow/compiler/xla/service/hlo_module.h"
     22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     23 #include "tensorflow/compiler/xla/shape_util.h"
     24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     25 #include "tensorflow/compiler/xla/tests/test_utils.h"
     26 
     27 namespace xla {
     28 
     29 class HloSubcomputationUnificationTest : public HloTestBase {
     30  protected:
     31   HloSubcomputationUnificationTest() {}
     32 
     33   std::unique_ptr<HloComputation> CreateR0S32IdentityComputation() {
     34     auto builder = HloComputation::Builder("Identity");
     35     builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
     36     return builder.Build();
     37   }
     38 
     39   std::unique_ptr<HloComputation> CreateR0S32AdditionComputation() {
     40     auto builder = HloComputation::Builder("Addition");
     41     auto x =
     42         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
     43     auto y =
     44         builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y"));
     45     builder.AddInstruction(
     46         HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
     47     return builder.Build();
     48   }
     49 
     50   std::unique_ptr<HloComputation> CreateR1S32AdditionComputation(
     51       const Shape& shape) {
     52     auto builder = HloComputation::Builder("Addition");
     53     auto x =
     54         builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
     55     auto y =
     56         builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y"));
     57     builder.AddInstruction(
     58         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
     59     return builder.Build();
     60   }
     61 
     62   Shape r0s32_ = ShapeUtil::MakeShape(S32, {});
     63   Shape r0f32_ = ShapeUtil::MakeShape(S32, {});
     64   Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5});
     65   Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3});
     66 };
     67 
     68 TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
     69   auto module = CreateNewModule();
     70   auto builder = HloComputation::Builder(TestName());
     71 
     72   auto callee1 =
     73       module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
     74   auto callee2 =
     75       module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
     76 
     77   auto constant = builder.AddInstruction(
     78       HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
     79   auto x = builder.AddInstruction(
     80       HloInstruction::CreateCall(r0s32_, {constant}, callee1));
     81   auto y = builder.AddInstruction(
     82       HloInstruction::CreateCall(r0s32_, {constant}, callee2));
     83   builder.AddInstruction(
     84       HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
     85 
     86   module->AddEntryComputation(builder.Build());
     87 
     88   EXPECT_EQ(3, module->computation_count());
     89   EXPECT_NE(x->to_apply(), y->to_apply());
     90   if (VLOG_IS_ON(1)) {
     91     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
     92                                 "before unification",
     93                                 module->config().debug_options());
     94   }
     95   EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
     96   if (VLOG_IS_ON(1)) {
     97     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
     98                                 "after unification",
     99                                 module->config().debug_options());
    100   }
    101   EXPECT_EQ(2, module->computation_count());
    102   EXPECT_EQ(x->to_apply(), y->to_apply());
    103 }
    104 
    105 TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
    106   auto module = CreateNewModule();
    107   auto builder = HloComputation::Builder(TestName());
    108 
    109   auto callee1 =
    110       module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
    111   auto callee2 =
    112       module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
    113 
    114   auto constant1 = builder.AddInstruction(
    115       HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
    116   auto constant2 = builder.AddInstruction(
    117       HloInstruction::CreateConstant(Literal::CreateR0<int32>(3)));
    118   auto x = builder.AddInstruction(
    119       HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1));
    120   auto y = builder.AddInstruction(
    121       HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee2));
    122   builder.AddInstruction(
    123       HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
    124 
    125   module->AddEntryComputation(builder.Build());
    126 
    127   EXPECT_EQ(3, module->computation_count());
    128   EXPECT_NE(x->to_apply(), y->to_apply());
    129   if (VLOG_IS_ON(1)) {
    130     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
    131                                 "before unification",
    132                                 module->config().debug_options());
    133   }
    134   EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
    135   if (VLOG_IS_ON(1)) {
    136     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
    137                                 "after unification",
    138                                 module->config().debug_options());
    139   }
    140   EXPECT_EQ(2, module->computation_count());
    141   EXPECT_EQ(x->to_apply(), y->to_apply());
    142 }
    143 
    144 // Do not unify subcomputations with different parameter shapes.
    145 TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
    146   auto module = CreateNewModule();
    147   auto builder = HloComputation::Builder(TestName());
    148 
    149   auto callee1 =
    150       module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_));
    151   auto callee2 =
    152       module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_));
    153 
    154   auto param1 = builder.AddInstruction(
    155       HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
    156   auto param2 = builder.AddInstruction(
    157       HloInstruction::CreateParameter(1, r1s32_5_, "param2"));
    158   auto x = builder.AddInstruction(
    159       HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1));
    160   auto y = builder.AddInstruction(
    161       HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2));
    162   builder.AddInstruction(HloInstruction::CreateConcatenate(
    163       ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
    164 
    165   module->AddEntryComputation(builder.Build());
    166 
    167   EXPECT_EQ(3, module->computation_count());
    168   EXPECT_NE(x->to_apply(), y->to_apply());
    169   if (VLOG_IS_ON(1)) {
    170     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
    171                                 "before unification",
    172                                 module->config().debug_options());
    173   }
    174   EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
    175   if (VLOG_IS_ON(1)) {
    176     hlo_graph_dumper::DumpGraph(*module->entry_computation(),
    177                                 "after unification",
    178                                 module->config().debug_options());
    179   }
    180   EXPECT_EQ(3, module->computation_count());
    181   EXPECT_NE(x->to_apply(), y->to_apply());
    182 }
    183 
    184 // Regression test for b/31466798. Checks that entry_computation is still valid
    185 // after unification.
    186 TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
    187   auto module = CreateNewModule();
    188   for (int i = 0; i < 2; ++i) {
    189     HloComputation::Builder builder("pow");
    190     auto x =
    191         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    192     auto y =
    193         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
    194     builder.AddInstruction(
    195         HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
    196     if (i == 0) {
    197       module->AddEmbeddedComputation(builder.Build());
    198     } else {
    199       module->AddEntryComputation(builder.Build());
    200     }
    201   }
    202 
    203   EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
    204   EXPECT_EQ(1, module->computation_count());
    205   EXPECT_EQ(*module->computations().begin(), module->entry_computation());
    206 }
    207 
    208 }  // namespace xla
    209