1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_computation.h" 19 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 22 #include "tensorflow/compiler/xla/test.h" 23 #include "tensorflow/compiler/xla/tests/test_utils.h" 24 #include "tensorflow/compiler/xla/xla.pb.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 27 namespace xla { 28 namespace { 29 30 using ::tensorflow::strings::StrCat; 31 using ::testing::HasSubstr; 32 33 string TestName() { 34 return ::testing::UnitTest::GetInstance()->current_test_info()->name(); 35 } 36 37 class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { 38 public: 39 string RenderGraph(const string& graph, GraphKind graph_kind, 40 const DebugOptions& debug_options) override { 41 return graph; 42 } 43 44 private: 45 string last_graph_; 46 }; 47 48 XLA_REGISTER_GRAPH_RENDERER(DotRenderer); 49 50 TEST(HloGraphDumperTest, NestedFusion) { 51 HloComputation::Builder b("b"); 52 53 // Build param0 + param1 + param2 + param3 + param4. 54 auto shape = ShapeUtil::MakeShape(F32, {10, 100}); 55 std::vector<HloInstruction*> params; 56 for (int i = 0; i <= 4; ++i) { 57 params.push_back(b.AddInstruction( 58 HloInstruction::CreateParameter(i, shape, StrCat("param", i)))); 59 } 60 std::vector<HloInstruction*> sums; 61 sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( 62 shape, HloOpcode::kAdd, params[0], params[1]))); 63 for (int i = 0; i <= 2; ++i) { 64 sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( 65 shape, HloOpcode::kAdd, sums[i], params[i + 2]))); 66 } 67 68 HloModule m(TestName()); 69 m.AddEntryComputation(b.Build()); 70 HloComputation* root_computation = m.entry_computation(); 71 72 // Fuse into fusion(param0 + param1 + param2 + param3 + param4). 73 auto* outer_fusion = root_computation->CreateFusionInstruction( 74 {sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop); 75 76 // Fusing invalidates the pointers in sums -- the instructions are cloned when 77 // they're moved to the new computation. Get the updated pointers to sums. 78 std::vector<HloInstruction*> fused_sums; 79 for (auto* instr : outer_fusion->fused_instructions_computation() 80 ->MakeInstructionPostOrder()) { 81 if (instr->opcode() == HloOpcode::kAdd) { 82 fused_sums.push_back(instr); 83 } 84 } 85 86 // Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4). 87 auto* inner_fusion = 88 outer_fusion->fused_instructions_computation()->CreateFusionInstruction( 89 {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop); 90 91 // Generate the graph; all nodes should be present. 92 string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"", 93 DebugOptions()); 94 for (const HloComputation* computation : 95 {root_computation, // 96 inner_fusion->fused_instructions_computation(), 97 outer_fusion->fused_instructions_computation()}) { 98 for (const HloInstruction* instruction : computation->instructions()) { 99 EXPECT_THAT(graph, HasSubstr(instruction->name())); 100 } 101 } 102 103 // Dump a neighborhood around one of the inner sum nodes. We don't really 104 // care that the outer nodes are omitted -- whether they are or not is based 105 // fiddly heuristics -- but we do care that the node we asked for is printed. 106 const HloInstruction* inner_sum = nullptr; 107 for (const HloInstruction* instruction : 108 inner_fusion->fused_instructions_computation()->instructions()) { 109 if (instruction->opcode() == HloOpcode::kAdd) { 110 inner_sum = instruction; 111 break; 112 } 113 } 114 ASSERT_NE(inner_sum, nullptr); 115 EXPECT_THAT( 116 hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1), 117 HasSubstr(inner_sum->name())); 118 } 119 120 TEST(HloGraphDumperTest, Constant) { 121 HloComputation::Builder b("b"); 122 auto instruction = b.AddInstruction( 123 HloInstruction::CreateConstant(Literal::CreateR0<float>(-42))); 124 instruction->set_name("i_am_a_constant_root_instruction"); 125 HloModule m(TestName()); 126 HloComputation* root_computation = m.AddEntryComputation(b.Build()); 127 string graph = hlo_graph_dumper::DumpGraph( 128 *root_computation, /*label=*/"an_empty_graph", DebugOptions()); 129 EXPECT_THAT(graph, HasSubstr("an_empty_graph")); 130 EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); 131 } 132 133 } // anonymous namespace 134 } // namespace xla 135