1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/flatten_call_graph.h" 17 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/service/call_graph.h" 20 #include "tensorflow/compiler/xla/service/hlo_computation.h" 21 #include "tensorflow/compiler/xla/shape_util.h" 22 #include "tensorflow/compiler/xla/status_macros.h" 23 #include "tensorflow/compiler/xla/test.h" 24 #include "tensorflow/compiler/xla/test_helpers.h" 25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 26 #include "tensorflow/compiler/xla/util.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/lib/core/status_test_util.h" 29 30 namespace xla { 31 namespace { 32 33 class FlattenCallGraphTest : public HloTestBase { 34 protected: 35 // Build and return a trivial computation taking and returning a scalar. 36 std::unique_ptr<HloComputation> MakeScalarComputation() { 37 HloComputation::Builder builder(TestName() + ".ScalarComputation"); 38 HloInstruction* param0 = builder.AddInstruction( 39 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 40 builder.AddInstruction( 41 HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); 42 return builder.Build(); 43 } 44 45 // Build and return a computation which takes a scalar and maps (kMap) the 46 // given computation to the value 'callsites' number of times. 47 std::unique_ptr<HloComputation> MakeMappingComputation( 48 HloComputation* map_computation, int64 callsites) { 49 HloComputation::Builder builder(TestName() + ".MappingComputation"); 50 HloInstruction* param0 = builder.AddInstruction( 51 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 52 HloInstruction* last_value = param0; 53 for (int64 i = 0; i < callsites; ++i) { 54 last_value = builder.AddInstruction(HloInstruction::CreateMap( 55 kScalarShape, {last_value}, map_computation)); 56 } 57 return builder.Build(); 58 } 59 60 // Build and return a computation which takes a scalar and calls (kCall) the 61 // given computation with value 'callsites' number of times. 62 std::unique_ptr<HloComputation> MakeCallingComputation( 63 HloComputation* callee_computation, int64 callsites, 64 const string& suffix = ".CallingComputation") { 65 HloComputation::Builder builder(TestName() + suffix); 66 HloInstruction* param0 = builder.AddInstruction( 67 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 68 HloInstruction* last_value = param0; 69 for (int64 i = 0; i < callsites; ++i) { 70 last_value = builder.AddInstruction(HloInstruction::CreateCall( 71 kScalarShape, {last_value}, callee_computation)); 72 } 73 return builder.Build(); 74 } 75 76 // Build and return a computation which takes a scalar and returns a PRED 77 // value. 78 std::unique_ptr<HloComputation> MakeConditionComputation() { 79 HloComputation::Builder builder(TestName() + ".ConditionComputation"); 80 HloInstruction* param0 = builder.AddInstruction( 81 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 82 HloInstruction* zero = builder.AddInstruction( 83 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 84 builder.AddInstruction(HloInstruction::CreateBinary( 85 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); 86 return builder.Build(); 87 } 88 89 StatusOr<bool> RunFlattenCallGraph(HloModule* module) { 90 FlattenCallGraph flatten; 91 TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module)); 92 return result; 93 } 94 95 const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); 96 }; 97 98 TEST_F(FlattenCallGraphTest, ComplexGraph) { 99 // Test a call graph of a module with several computation called in various 100 // contexts. The call graph looks like: 101 // 102 // entry 103 // / | 104 // a | 105 // / | \ | 106 // b | cond 107 // \ | 108 // c 109 // 110 // Calls are made via kCall, kWhile, and kMap instructions. 111 auto module = CreateNewModule(); 112 HloComputation* cond_computation = 113 module->AddEmbeddedComputation(MakeConditionComputation()); 114 HloComputation* c_computation = 115 module->AddEmbeddedComputation(MakeScalarComputation()); 116 HloComputation* b_computation = module->AddEmbeddedComputation( 117 MakeMappingComputation(c_computation, /*callsites=*/1)); 118 119 HloComputation* a_computation; 120 { 121 HloComputation::Builder builder(TestName() + ".a"); 122 HloInstruction* param0 = builder.AddInstruction( 123 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 124 HloInstruction* call = builder.AddInstruction( 125 HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); 126 builder.AddInstruction(HloInstruction::CreateWhile( 127 kScalarShape, cond_computation, b_computation, call)); 128 a_computation = module->AddEmbeddedComputation(builder.Build()); 129 } 130 131 HloComputation* entry_computation; 132 { 133 HloComputation::Builder builder(TestName() + ".entry"); 134 HloInstruction* param0 = builder.AddInstruction( 135 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 136 builder.AddInstruction(HloInstruction::CreateWhile( 137 kScalarShape, cond_computation, a_computation, param0)); 138 entry_computation = module->AddEntryComputation(builder.Build()); 139 } 140 141 { 142 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); 143 EXPECT_TRUE(result); 144 std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get()); 145 const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); 146 EXPECT_EQ(1, c_node.caller_callsites().size()); 147 } 148 } 149 150 // Test corner case of a computation used as a body and a loop condition. 151 TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { 152 auto module = CreateNewModule(); 153 HloComputation* cond_computation; 154 { 155 HloComputation::Builder builder(TestName() + ".cond"); 156 HloInstruction* param0 = 157 builder.AddInstruction(HloInstruction::CreateParameter( 158 0, ShapeUtil::MakeShape(PRED, {}), "param0")); 159 HloInstruction* false_constant = builder.AddInstruction( 160 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 161 builder.AddInstruction( 162 HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), 163 HloOpcode::kEq, param0, false_constant)); 164 cond_computation = module->AddEmbeddedComputation(builder.Build()); 165 } 166 167 HloComputation* entry_computation; 168 { 169 HloComputation::Builder builder(TestName() + ".entry"); 170 HloInstruction* false_constant = builder.AddInstruction( 171 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 172 builder.AddInstruction(HloInstruction::CreateWhile( 173 ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, 174 false_constant)); 175 entry_computation = module->AddEntryComputation(builder.Build()); 176 } 177 178 { 179 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 180 const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); 181 EXPECT_EQ(2, cond_node.caller_callsites().size()); 182 } 183 184 { 185 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); 186 EXPECT_TRUE(result); 187 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 188 const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); 189 EXPECT_EQ(1, cond_node.caller_callsites().size()); 190 } 191 } 192 193 // Test flattening of a nested calling computations. 194 // 195 // Entry 196 // / \ 197 // \ / 198 // B 199 // / \ 200 // \ / 201 // C 202 // 203 TEST_F(FlattenCallGraphTest, FlattenCalls) { 204 auto module = CreateNewModule(); 205 HloComputation* c_computation = 206 module->AddEmbeddedComputation(MakeScalarComputation()); 207 208 HloComputation* b_computation = module->AddEmbeddedComputation( 209 MakeCallingComputation(c_computation, /*callsites=*/2, ".B")); 210 211 module->AddEntryComputation( 212 MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); 213 214 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); 215 EXPECT_TRUE(result); 216 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 217 EXPECT_EQ(7, module->computation_count()); 218 219 const CallGraphNode& c_node = call_graph->GetNode(c_computation); 220 EXPECT_EQ(1, c_node.caller_callsites().size()); 221 222 const CallGraphNode& b_node = call_graph->GetNode(b_computation); 223 EXPECT_EQ(1, b_node.caller_callsites().size()); 224 } 225 226 TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { 227 auto module = CreateNewModule(); 228 HloComputation* sub_computation = 229 module->AddEmbeddedComputation(MakeScalarComputation()); 230 231 // Create entry computation, which is a conditional that has the same 232 // computation in the true and false branch. 233 HloComputation::Builder builder(TestName()); 234 auto pred = builder.AddInstruction( 235 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 236 auto constant1 = builder.AddInstruction( 237 HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f))); 238 auto constant2 = builder.AddInstruction( 239 HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f))); 240 builder.AddInstruction(HloInstruction::CreateConditional( 241 kScalarShape, pred, constant1, sub_computation, constant2, 242 sub_computation)); 243 module->AddEntryComputation(builder.Build()); 244 EXPECT_EQ(2, module->computation_count()); 245 246 TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); 247 EXPECT_TRUE(result); 248 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 249 // The true and false computations must now be different. 250 EXPECT_EQ(3, module->computation_count()); 251 252 const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); 253 EXPECT_EQ(1, sub_node.caller_callsites().size()); 254 } 255 256 } // namespace 257 } // namespace xla 258