Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     20 #include "tensorflow/compiler/xla/service/hlo_module.h"
     21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     22 #include "tensorflow/compiler/xla/test.h"
     23 #include "tensorflow/compiler/xla/tests/test_utils.h"
     24 #include "tensorflow/compiler/xla/xla.pb.h"
     25 #include "tensorflow/core/lib/strings/strcat.h"
     26 
     27 namespace xla {
     28 namespace {
     29 
     30 using ::tensorflow::strings::StrCat;
     31 using ::testing::HasSubstr;
     32 
     33 string TestName() {
     34   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
     35 }
     36 
     37 class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
     38  public:
     39   string RenderGraph(const string& graph, GraphKind graph_kind,
     40                      const DebugOptions& debug_options) override {
     41     return graph;
     42   }
     43 
     44  private:
     45   string last_graph_;
     46 };
     47 
     48 XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
     49 
     50 TEST(HloGraphDumperTest, NestedFusion) {
     51   HloComputation::Builder b("b");
     52 
     53   // Build param0 + param1 + param2 + param3 + param4.
     54   auto shape = ShapeUtil::MakeShape(F32, {10, 100});
     55   std::vector<HloInstruction*> params;
     56   for (int i = 0; i <= 4; ++i) {
     57     params.push_back(b.AddInstruction(
     58         HloInstruction::CreateParameter(i, shape, StrCat("param", i))));
     59   }
     60   std::vector<HloInstruction*> sums;
     61   sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
     62       shape, HloOpcode::kAdd, params[0], params[1])));
     63   for (int i = 0; i <= 2; ++i) {
     64     sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
     65         shape, HloOpcode::kAdd, sums[i], params[i + 2])));
     66   }
     67 
     68   HloModule m(TestName());
     69   m.AddEntryComputation(b.Build());
     70   HloComputation* root_computation = m.entry_computation();
     71 
     72   // Fuse into fusion(param0 + param1 + param2 + param3 + param4).
     73   auto* outer_fusion = root_computation->CreateFusionInstruction(
     74       {sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop);
     75 
     76   // Fusing invalidates the pointers in sums -- the instructions are cloned when
     77   // they're moved to the new computation.  Get the updated pointers to sums.
     78   std::vector<HloInstruction*> fused_sums;
     79   for (auto* instr : outer_fusion->fused_instructions_computation()
     80                          ->MakeInstructionPostOrder()) {
     81     if (instr->opcode() == HloOpcode::kAdd) {
     82       fused_sums.push_back(instr);
     83     }
     84   }
     85 
     86   // Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4).
     87   auto* inner_fusion =
     88       outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
     89           {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop);
     90 
     91   // Generate the graph; all nodes should be present.
     92   string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"",
     93                                              DebugOptions());
     94   for (const HloComputation* computation :
     95        {root_computation,  //
     96         inner_fusion->fused_instructions_computation(),
     97         outer_fusion->fused_instructions_computation()}) {
     98     for (const HloInstruction* instruction : computation->instructions()) {
     99       EXPECT_THAT(graph, HasSubstr(instruction->name()));
    100     }
    101   }
    102 
    103   // Dump a neighborhood around one of the inner sum nodes.  We don't really
    104   // care that the outer nodes are omitted -- whether they are or not is based
    105   // fiddly heuristics -- but we do care that the node we asked for is printed.
    106   const HloInstruction* inner_sum = nullptr;
    107   for (const HloInstruction* instruction :
    108        inner_fusion->fused_instructions_computation()->instructions()) {
    109     if (instruction->opcode() == HloOpcode::kAdd) {
    110       inner_sum = instruction;
    111       break;
    112     }
    113   }
    114   ASSERT_NE(inner_sum, nullptr);
    115   EXPECT_THAT(
    116       hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1),
    117       HasSubstr(inner_sum->name()));
    118 }
    119 
    120 TEST(HloGraphDumperTest, Constant) {
    121   HloComputation::Builder b("b");
    122   auto instruction = b.AddInstruction(
    123       HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
    124   instruction->set_name("i_am_a_constant_root_instruction");
    125   HloModule m(TestName());
    126   HloComputation* root_computation = m.AddEntryComputation(b.Build());
    127   string graph = hlo_graph_dumper::DumpGraph(
    128       *root_computation, /*label=*/"an_empty_graph", DebugOptions());
    129   EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
    130   EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
    131 }
    132 
    133 }  // anonymous namespace
    134 }  // namespace xla
    135