Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_module.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/ptr_util.h"
     20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     22 #include "tensorflow/compiler/xla/shape_util.h"
     23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/core/lib/gtl/array_slice.h"
     28 
     29 namespace xla {
     30 
     31 namespace {
     32 
     33 class HloModuleTest : public HloTestBase {
     34  protected:
     35   HloModuleTest() {}
     36 
     37   // Create a computation which returns a constant.
     38   std::unique_ptr<HloComputation> CreateConstantComputation() {
     39     auto builder = HloComputation::Builder("Constant");
     40     builder.AddInstruction(
     41         HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
     42     return builder.Build();
     43   }
     44 
     45   // Creates a computation which calls the given zero-parameter computations.
     46   std::unique_ptr<HloComputation> CreateCallComputation(
     47       tensorflow::gtl::ArraySlice<HloComputation*> computations) {
     48     auto builder = HloComputation::Builder("Call");
     49     for (auto computation : computations) {
     50       builder.AddInstruction(
     51           HloInstruction::CreateCall(r0f32_, {}, computation));
     52     }
     53     return builder.Build();
     54   }
     55 
     56   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
     57 };
     58 
     59 TEST_F(HloModuleTest, OneComputationPostOrder) {
     60   // Create a module with a single computation.
     61   auto module = CreateNewModule();
     62   auto computation = module->AddEntryComputation(CreateConstantComputation());
     63 
     64   EXPECT_THAT(module->MakeComputationPostOrder(),
     65               ::testing::ElementsAre(computation));
     66 }
     67 
     68 TEST_F(HloModuleTest, TwoComputationsPostOrder) {
     69   // Create a module with two unconnected computations.
     70   auto module = CreateNewModule();
     71   auto computation1 = module->AddEntryComputation(CreateConstantComputation());
     72   auto computation2 =
     73       module->AddEmbeddedComputation(CreateConstantComputation());
     74 
     75   EXPECT_THAT(module->MakeComputationPostOrder(),
     76               ::testing::UnorderedElementsAre(computation1, computation2));
     77 
     78   // We specified the same name for both computations, but the HloModule should
     79   // have made the names unique.
     80   EXPECT_EQ(computation1->name(), "Constant");
     81   EXPECT_EQ(computation2->name(), "Constant.1");
     82 }
     83 
     84 TEST_F(HloModuleTest, CloneTest) {
     85   // Create and copy a module with a diamond call graph of computations.
     86   auto module = CreateNewModule();
     87   auto computation1 =
     88       module->AddEmbeddedComputation(CreateConstantComputation());
     89   auto computation2 =
     90       module->AddEmbeddedComputation(CreateCallComputation({computation1}));
     91   auto computation3 =
     92       module->AddEmbeddedComputation(CreateCallComputation({computation1}));
     93   module->AddEntryComputation(
     94       CreateCallComputation({computation2, computation3}));
     95 
     96   auto post_order = module->MakeComputationPostOrder();
     97   auto cloned_module = module->Clone("copy");
     98   auto post_order_copied = cloned_module->MakeComputationPostOrder();
     99 
    100   EXPECT_EQ(post_order.size(), post_order_copied.size());
    101   for (auto origin = post_order.begin(), copied = post_order_copied.begin();
    102        origin != post_order.end() && copied != post_order_copied.end();
    103        ++origin, ++copied) {
    104     EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
    105   }
    106 }
    107 
    108 TEST_F(HloModuleTest, CloneHasFusion) {
    109   auto module = CreateNewModule();
    110 
    111   // Create the fused computation.
    112   HloComputation* fused_computation;
    113   {
    114     auto b = HloComputation::Builder("Fused");
    115     auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    116     b.AddInstruction(
    117         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
    118     fused_computation = module->AddEmbeddedComputation(b.Build());
    119   }
    120 
    121   // Create the entry computation.
    122   {
    123     auto b = HloComputation::Builder("Entry");
    124     auto input = b.AddInstruction(
    125         HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    126     b.AddInstruction(
    127         HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
    128                                      /*operands=*/{input}, fused_computation));
    129     module->AddEntryComputation(b.Build());
    130   }
    131 
    132   auto post_order = module->MakeComputationPostOrder();
    133   auto cloned_module = module->Clone("copy");
    134   auto post_order_copied = cloned_module->MakeComputationPostOrder();
    135 
    136   EXPECT_EQ(post_order.size(), post_order_copied.size());
    137   for (auto origin = post_order.begin(), copied = post_order_copied.begin();
    138        origin != post_order.end() && copied != post_order_copied.end();
    139        ++origin, ++copied) {
    140     if ((*origin)->name() == "Fused") {
    141       // Clone of the fused computation is handled when its fusion instruction
    142       // is cloned, which always use suffix ".clone".
    143       EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name());
    144     } else {
    145       EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
    146     }
    147   }
    148 }
    149 
    150 TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
    151   // Create a module with a diamond call graph of computations.
    152   auto module = CreateNewModule();
    153   auto computation1 =
    154       module->AddEmbeddedComputation(CreateConstantComputation());
    155   auto computation2 =
    156       module->AddEmbeddedComputation(CreateCallComputation({computation1}));
    157   auto computation3 =
    158       module->AddEmbeddedComputation(CreateCallComputation({computation1}));
    159   auto computation4 = module->AddEntryComputation(
    160       CreateCallComputation({computation2, computation3}));
    161 
    162   auto post_order = module->MakeComputationPostOrder();
    163   EXPECT_THAT(post_order,
    164               ::testing::UnorderedElementsAre(computation1, computation2,
    165                                               computation3, computation4));
    166   EXPECT_EQ(post_order.back(), computation4);
    167   EXPECT_EQ(post_order.front(), computation1);
    168 }
    169 
    170 TEST_F(HloModuleTest, LargeConstantToString) {
    171   // Create a module with a single computation.
    172   auto module = CreateNewModule();
    173   auto builder = HloComputation::Builder("Constant");
    174   std::vector<float> values(16, 42.0);
    175   builder.AddInstruction(
    176       HloInstruction::CreateConstant(Literal::CreateR1<float>(values)));
    177   module->AddEntryComputation(builder.Build());
    178 
    179   EXPECT_EQ(
    180       "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
    181       "ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
    182       module->ToString(HloPrintOptions().set_print_large_constants(false)));
    183 
    184   EXPECT_EQ(
    185       "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
    186       "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
    187       "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
    188       module->ToString(HloPrintOptions().set_print_large_constants(true)));
    189 }
    190 
    191 TEST_F(HloModuleTest, UniqueModuleId) {
    192   auto module_a = CreateNewModule();
    193   auto module_b = CreateNewModule();
    194   EXPECT_NE(module_a->unique_id(), module_b->unique_id());
    195 }
    196 
    197 }  // namespace
    198 
    199 }  // namespace xla
    200