Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/call_inliner.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/layout_util.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
     29 #include "tensorflow/compiler/xla/shape_util.h"
     30 #include "tensorflow/compiler/xla/test.h"
     31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/compiler/xla/xla_data.pb.h"
     34 #include "tensorflow/core/lib/core/status_test_util.h"
     35 #include "tensorflow/core/lib/strings/str_util.h"
     36 
     37 namespace op = xla::testing::opcode_matchers;
     38 
     39 namespace xla {
     40 namespace {
     41 
     42 // Tests for call inlining that are most tractable at the HLO level (vs
     43 // ComputationBuilder API in call_test.cc).
     44 using CallInlinerTest = HloTestBase;
     45 
     46 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
     47   // "inner" computation just has a control dependency from the "zero" value to
     48   // the "one" value.
     49   HloComputation::Builder inner(TestName() + ".inner");
     50   HloInstruction* zero = inner.AddInstruction(
     51       HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
     52   HloInstruction* one = inner.AddInstruction(
     53       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
     54   TF_ASSERT_OK(zero->AddControlDependencyTo(one));
     55   auto module = CreateNewModule();
     56   HloComputation* inner_computation =
     57       module->AddEmbeddedComputation(inner.Build());
     58 
     59   // "outer" computation just calls the "inner" computation.
     60   HloComputation::Builder outer(TestName() + ".outer");
     61   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
     62   outer.AddInstruction(
     63       HloInstruction::CreateCall(r0f32, {}, inner_computation));
     64 
     65   auto computation = module->AddEntryComputation(outer.Build());
     66 
     67   CallInliner call_inliner;
     68   TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
     69   ASSERT_TRUE(mutated);
     70   EXPECT_THAT(computation->root_instruction(), op::Constant());
     71   EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
     72             42);
     73   ASSERT_EQ(1, computation->root_instruction()->control_predecessors().size());
     74   auto prior = computation->root_instruction()->control_predecessors()[0];
     75   EXPECT_THAT(prior, op::Constant());
     76   EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
     77 }
     78 
     79 // Tests for referential transparency (a function that calls a function that
     80 // returns false should be identical to just returning false).
     81 TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
     82   const Shape pred = ShapeUtil::MakeShape(PRED, {});
     83   auto module = CreateNewModule();
     84 
     85   // Create a lambda that calls a function that returns the false predicate.
     86   // Note we also use this lambda twice by reference, just to make the test a
     87   // little trickier.
     88   HloComputation::Builder just_false(TestName() + ".false");
     89   just_false.AddInstruction(
     90       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
     91   HloComputation* false_computation =
     92       module->AddEmbeddedComputation(just_false.Build());
     93 
     94   HloComputation::Builder call_false_builder(TestName() + ".call_false");
     95   call_false_builder.AddInstruction(
     96       HloInstruction::CreateCall(pred, {}, false_computation));
     97   HloComputation* call_false =
     98       module->AddEmbeddedComputation(call_false_builder.Build());
     99 
    100   HloComputation::Builder outer(TestName() + ".outer");
    101   HloInstruction* init_value = outer.AddInstruction(
    102       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    103   outer.AddInstruction(
    104       HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
    105 
    106   auto computation = module->AddEntryComputation(outer.Build());
    107 
    108   CallInliner call_inliner;
    109   TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
    110   ASSERT_TRUE(mutated);
    111   EXPECT_THAT(
    112       computation->root_instruction()->while_condition()->root_instruction(),
    113       op::Constant());
    114   EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
    115               op::Constant());
    116 }
    117 
    118 // Check CallInliner::Inline, which inlines a specific call without running the
    119 // whole pass.
    120 TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
    121   const Shape pred = ShapeUtil::MakeShape(PRED, {});
    122   auto module = CreateNewModule();
    123 
    124   HloComputation::Builder just_false(TestName() + ".false");
    125   auto* true_constant = just_false.AddInstruction(
    126       HloInstruction::CreateConstant(Literal::CreateR1<bool>({true})));
    127   auto* false_constant = just_false.AddInstruction(
    128       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    129   TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant));
    130   HloComputation* false_computation =
    131       module->AddEmbeddedComputation(just_false.Build());
    132 
    133   HloComputation::Builder call_false_builder(TestName() + ".call_false");
    134   HloInstruction* call = call_false_builder.AddInstruction(
    135       HloInstruction::CreateCall(pred, {}, false_computation));
    136   auto computation = module->AddEntryComputation(call_false_builder.Build());
    137 
    138   TF_ASSERT_OK(CallInliner::Inline(call).status());
    139   EXPECT_THAT(computation->root_instruction(), op::Constant());
    140   EXPECT_THAT(computation->root_instruction()->control_successors(),
    141               ElementsAre(op::Constant()));
    142 }
    143 
    144 TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
    145   const Shape f32 = ShapeUtil::MakeShape(F32, {});
    146   auto module = CreateNewModule();
    147 
    148   HloComputation::Builder outfeeder(TestName() + ".outfeeder");
    149   auto value = outfeeder.AddInstruction(
    150       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
    151   outfeeder.AddInstruction(
    152       HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/""));
    153 
    154   auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build());
    155 
    156   HloComputation::Builder outer(TestName() + ".outer");
    157   outer.AddInstruction(HloInstruction::CreateCall(
    158       ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation));
    159 
    160   module->AddEntryComputation(outer.Build());
    161 
    162   CallInliner call_inliner;
    163   TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
    164   ASSERT_TRUE(mutated);
    165 }
    166 
    167 }  // namespace
    168 }  // namespace xla
    169