1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_module.h" 17 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/ptr_util.h" 20 #include "tensorflow/compiler/xla/service/hlo_computation.h" 21 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/core/lib/gtl/array_slice.h" 28 29 namespace xla { 30 31 namespace { 32 33 class HloModuleTest : public HloTestBase { 34 protected: 35 HloModuleTest() {} 36 37 // Create a computation which returns a constant. 38 std::unique_ptr<HloComputation> CreateConstantComputation() { 39 auto builder = HloComputation::Builder("Constant"); 40 builder.AddInstruction( 41 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 42 return builder.Build(); 43 } 44 45 // Creates a computation which calls the given zero-parameter computations. 46 std::unique_ptr<HloComputation> CreateCallComputation( 47 tensorflow::gtl::ArraySlice<HloComputation*> computations) { 48 auto builder = HloComputation::Builder("Call"); 49 for (auto computation : computations) { 50 builder.AddInstruction( 51 HloInstruction::CreateCall(r0f32_, {}, computation)); 52 } 53 return builder.Build(); 54 } 55 56 Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); 57 }; 58 59 TEST_F(HloModuleTest, OneComputationPostOrder) { 60 // Create a module with a single computation. 61 auto module = CreateNewModule(); 62 auto computation = module->AddEntryComputation(CreateConstantComputation()); 63 64 EXPECT_THAT(module->MakeComputationPostOrder(), 65 ::testing::ElementsAre(computation)); 66 } 67 68 TEST_F(HloModuleTest, TwoComputationsPostOrder) { 69 // Create a module with two unconnected computations. 70 auto module = CreateNewModule(); 71 auto computation1 = module->AddEntryComputation(CreateConstantComputation()); 72 auto computation2 = 73 module->AddEmbeddedComputation(CreateConstantComputation()); 74 75 EXPECT_THAT(module->MakeComputationPostOrder(), 76 ::testing::UnorderedElementsAre(computation1, computation2)); 77 78 // We specified the same name for both computations, but the HloModule should 79 // have made the names unique. 80 EXPECT_EQ(computation1->name(), "Constant"); 81 EXPECT_EQ(computation2->name(), "Constant.1"); 82 } 83 84 TEST_F(HloModuleTest, CloneTest) { 85 // Create and copy a module with a diamond call graph of computations. 86 auto module = CreateNewModule(); 87 auto computation1 = 88 module->AddEmbeddedComputation(CreateConstantComputation()); 89 auto computation2 = 90 module->AddEmbeddedComputation(CreateCallComputation({computation1})); 91 auto computation3 = 92 module->AddEmbeddedComputation(CreateCallComputation({computation1})); 93 module->AddEntryComputation( 94 CreateCallComputation({computation2, computation3})); 95 96 auto post_order = module->MakeComputationPostOrder(); 97 auto cloned_module = module->Clone("copy"); 98 auto post_order_copied = cloned_module->MakeComputationPostOrder(); 99 100 EXPECT_EQ(post_order.size(), post_order_copied.size()); 101 for (auto origin = post_order.begin(), copied = post_order_copied.begin(); 102 origin != post_order.end() && copied != post_order_copied.end(); 103 ++origin, ++copied) { 104 EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); 105 } 106 } 107 108 TEST_F(HloModuleTest, CloneHasFusion) { 109 auto module = CreateNewModule(); 110 111 // Create the fused computation. 112 HloComputation* fused_computation; 113 { 114 auto b = HloComputation::Builder("Fused"); 115 auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); 116 b.AddInstruction( 117 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x)); 118 fused_computation = module->AddEmbeddedComputation(b.Build()); 119 } 120 121 // Create the entry computation. 122 { 123 auto b = HloComputation::Builder("Entry"); 124 auto input = b.AddInstruction( 125 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 126 b.AddInstruction( 127 HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, 128 /*operands=*/{input}, fused_computation)); 129 module->AddEntryComputation(b.Build()); 130 } 131 132 auto post_order = module->MakeComputationPostOrder(); 133 auto cloned_module = module->Clone("copy"); 134 auto post_order_copied = cloned_module->MakeComputationPostOrder(); 135 136 EXPECT_EQ(post_order.size(), post_order_copied.size()); 137 for (auto origin = post_order.begin(), copied = post_order_copied.begin(); 138 origin != post_order.end() && copied != post_order_copied.end(); 139 ++origin, ++copied) { 140 if ((*origin)->name() == "Fused") { 141 // Clone of the fused computation is handled when its fusion instruction 142 // is cloned, which always use suffix ".clone". 143 EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name()); 144 } else { 145 EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); 146 } 147 } 148 } 149 150 TEST_F(HloModuleTest, DiamondComputationsPostOrder) { 151 // Create a module with a diamond call graph of computations. 152 auto module = CreateNewModule(); 153 auto computation1 = 154 module->AddEmbeddedComputation(CreateConstantComputation()); 155 auto computation2 = 156 module->AddEmbeddedComputation(CreateCallComputation({computation1})); 157 auto computation3 = 158 module->AddEmbeddedComputation(CreateCallComputation({computation1})); 159 auto computation4 = module->AddEntryComputation( 160 CreateCallComputation({computation2, computation3})); 161 162 auto post_order = module->MakeComputationPostOrder(); 163 EXPECT_THAT(post_order, 164 ::testing::UnorderedElementsAre(computation1, computation2, 165 computation3, computation4)); 166 EXPECT_EQ(post_order.back(), computation4); 167 EXPECT_EQ(post_order.front(), computation1); 168 } 169 170 TEST_F(HloModuleTest, LargeConstantToString) { 171 // Create a module with a single computation. 172 auto module = CreateNewModule(); 173 auto builder = HloComputation::Builder("Constant"); 174 std::vector<float> values(16, 42.0); 175 builder.AddInstruction( 176 HloInstruction::CreateConstant(Literal::CreateR1<float>(values))); 177 module->AddEntryComputation(builder.Build()); 178 179 EXPECT_EQ( 180 "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " 181 "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", 182 module->ToString(HloPrintOptions().set_print_large_constants(false))); 183 184 EXPECT_EQ( 185 "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " 186 "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " 187 "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", 188 module->ToString(HloPrintOptions().set_print_large_constants(true))); 189 } 190 191 TEST_F(HloModuleTest, UniqueModuleId) { 192 auto module_a = CreateNewModule(); 193 auto module_b = CreateNewModule(); 194 EXPECT_NE(module_a->unique_id(), module_b->unique_id()); 195 } 196 197 } // namespace 198 199 } // namespace xla 200