/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/flatten_call_graph.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

class FlattenCallGraphTest : public HloTestBase {
 protected:
  // Build and return a trivial computation taking and returning a scalar.
  std::unique_ptr<HloComputation> MakeScalarComputation() {
    HloComputation::Builder builder(TestName() + ".ScalarComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(
        HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and applies (kMap) the
  // given computation to the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeMappingComputation(
      HloComputation* map_computation, int64 callsites) {
    HloComputation::Builder builder(TestName() + ".MappingComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateMap(
          kScalarShape, {last_value}, map_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and calls (kCall) the
  // given computation with the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeCallingComputation(
      HloComputation* callee_computation, int64 callsites,
      const string& suffix = ".CallingComputation") {
    HloComputation::Builder builder(TestName() + suffix);
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateCall(
          kScalarShape, {last_value}, callee_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and returns a PRED
  // value.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    HloComputation::Builder builder(TestName() + ".ConditionComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* zero = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
    return builder.Build();
  }

  StatusOr<bool> RunFlattenCallGraph(HloModule* module) {
    FlattenCallGraph flatten;
    TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module));
    return result;
  }

  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
};

TEST_F(FlattenCallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
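  //
  // After FlattenCallGraph runs, every computation should be reachable from a
  // single callsite, so a computation such as c, which is reached from both a
  // and b above, is expected to be cloned rather than shared.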
  auto module = CreateNewModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  {
    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
    EXPECT_TRUE(result);
    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
    const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
    EXPECT_EQ(1, c_node.caller_callsites().size());
  }
}

// Test the corner case of a computation used as both the condition and the
// body of a while loop.
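// FlattenCallGraph is expected to clone the shared computation so that the
// while instruction's condition and body become two distinct computations.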
TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
  auto module = CreateNewModule();
  HloComputation* cond_computation;
  {
    HloComputation::Builder builder(TestName() + ".cond");
    HloInstruction* param0 =
        builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(PRED, {}), "param0"));
    HloInstruction* false_constant = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    builder.AddInstruction(
        HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
                                     HloOpcode::kEq, param0, false_constant));
    cond_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* false_constant = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    builder.AddInstruction(HloInstruction::CreateWhile(
        ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
        false_constant));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  {
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
    EXPECT_EQ(2, cond_node.caller_callsites().size());
  }

  {
    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
    EXPECT_TRUE(result);
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
    EXPECT_EQ(1, cond_node.caller_callsites().size());
  }
}

// Test flattening of nested calling computations.
//
//   Entry
//    / \
//    \ /
//     B
//    / \
//    \ /
//     C
//
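// Entry calls B twice and B calls C twice. After flattening, each callsite
// should target its own copy of the callee: Entry's two calls go to two
// copies of B, and the two calls in each copy of B go to four copies of C,
// giving 1 + 2 + 4 = 7 computations in total.
//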
TEST_F(FlattenCallGraphTest, FlattenCalls) {
  auto module = CreateNewModule();
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeCallingComputation(c_computation, /*callsites=*/2, ".B"));

  module->AddEntryComputation(
      MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));

  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(result);
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(7, module->computation_count());

  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
  EXPECT_EQ(1, c_node.caller_callsites().size());

  const CallGraphNode& b_node = call_graph->GetNode(b_computation);
  EXPECT_EQ(1, b_node.caller_callsites().size());
}

TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
  auto module = CreateNewModule();
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  // Create the entry computation, a conditional that uses the same
  // computation in both its true and false branches.
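  // FlattenCallGraph is expected to clone sub_computation so that each branch
  // of the conditional calls its own copy.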
  HloComputation::Builder builder(TestName());
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
  builder.AddInstruction(HloInstruction::CreateConditional(
      kScalarShape, pred, constant1, sub_computation, constant2,
      sub_computation));
  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(2, module->computation_count());

  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
  EXPECT_TRUE(result);
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  // The true and false computations must now be different.
  EXPECT_EQ(3, module->computation_count());

  const CallGraphNode& sub_node = call_graph->GetNode(sub_computation);
  EXPECT_EQ(1, sub_node.caller_callsites().size());
}

}  // namespace
}  // namespace xla