1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/call_graph.h" 17 18 #include <queue> 19 20 #include "tensorflow/compiler/xla/map_util.h" 21 #include "tensorflow/compiler/xla/ptr_util.h" 22 #include "tensorflow/compiler/xla/status_macros.h" 23 #include "tensorflow/compiler/xla/util.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/strings/str_util.h" 27 #include "tensorflow/core/lib/strings/strcat.h" 28 #include "tensorflow/core/lib/strings/stringprintf.h" 29 #include "tensorflow/core/platform/types.h" 30 31 namespace xla { 32 33 using ::tensorflow::strings::Appendf; 34 using ::tensorflow::strings::StrCat; 35 36 string CallContextToString(CallContext context) { 37 switch (context) { 38 case CallContext::kNone: 39 return "kNone"; 40 case CallContext::kSequential: 41 return "kSequential"; 42 case CallContext::kParallel: 43 return "kParallel"; 44 case CallContext::kBoth: 45 return "kBoth"; 46 } 47 } 48 49 std::ostream& operator<<(std::ostream& out, const CallContext& context) { 50 out << CallContextToString(context); 51 return out; 52 } 53 54 CallContext GetInstructionCallContext(const HloInstruction* instruction) { 55 switch (instruction->opcode()) { 56 case HloOpcode::kCall: 57 case HloOpcode::kConditional: 58 case HloOpcode::kWhile: 59 return CallContext::kSequential; 60 case HloOpcode::kMap: 61 case HloOpcode::kReduce: 62 case HloOpcode::kReduceWindow: 63 case HloOpcode::kSelectAndScatter: 64 case HloOpcode::kFusion: 65 return CallContext::kParallel; 66 default: 67 return CallContext::kNone; 68 } 69 } 70 71 string CallSite::ToString() const { 72 return StrCat(instruction()->name(), " calls in context ", 73 CallContextToString(context()), ": ", 74 tensorflow::str_util::Join( 75 called_computations(), ", ", 76 [](string* out, const HloComputation* computation) { 77 out->append(computation->name()); 78 })); 79 } 80 81 CallGraphNode::CallGraphNode(HloComputation* computation) 82 : computation_(computation) {} 83 84 const CallSite* CallGraphNode::GetCallSite( 85 const HloInstruction* instruction) const { 86 auto it = callsite_instructions_.find(instruction); 87 if (it == callsite_instructions_.end()) { 88 return nullptr; 89 } 90 return &callsites_[it->second]; 91 } 92 93 void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { 94 caller_callsites_.push_back(caller_callsite); 95 HloComputation* caller = caller_callsite.instruction()->parent(); 96 if (!ContainsKey(caller_set_, caller)) { 97 callers_.push_back(caller); 98 caller_set_.insert(caller); 99 } 100 } 101 102 void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { 103 CHECK_EQ(instruction->parent(), computation()); 104 const CallContext context = GetInstructionCallContext(instruction); 105 if (!instruction->called_computations().empty()) { 106 CHECK(context == CallContext::kSequential || 107 context == CallContext::kParallel); 108 callsite_instructions_.insert({instruction, callsites_.size()}); 109 callsites_.push_back( 110 CallSite(instruction, instruction->called_computations(), context)); 111 // Update callee computations to include any new computations called by this 112 // instruction. 113 for (auto* callee : callsites_.back().called_computations()) { 114 if (!ContainsKey(callee_set_, callee)) { 115 callees_.push_back(callee); 116 callee_set_.insert(callee); 117 } 118 } 119 } 120 } 121 122 CallGraph::CallGraph(const HloModule* module) : module_(module) {} 123 124 const CallGraphNode& CallGraph::GetNode( 125 const HloComputation* computation) const { 126 auto it = node_indices_.find(computation); 127 CHECK(it != node_indices_.end()); 128 return nodes_[it->second]; 129 } 130 131 CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { 132 auto it = node_indices_.find(computation); 133 CHECK(it != node_indices_.end()); 134 return nodes_[it->second]; 135 } 136 137 bool CallGraph::DominatesHelper( 138 const HloComputation* a, const HloComputation* b, 139 tensorflow::gtl::FlatSet<const HloComputation*>* visited) const { 140 if (a == b || ContainsKey(*visited, b)) { 141 // The call graph is guaranteed to be acyclic so any previously visited node 142 // we encounter was already determined to be dominated. 143 return true; 144 } 145 146 const CallGraphNode& b_node = GetNode(b); 147 if (b_node.callers().empty()) { 148 // We reached a root node without hitting 'a'. 'a' does not dominate 'b'. 149 return false; 150 } 151 152 // Walk up the callers of 'b' until we hit 'a' or a root node (no callers). 153 visited->insert(b); 154 for (const HloComputation* b_caller : b_node.callers()) { 155 if (!DominatesHelper(a, b_caller, visited)) { 156 return false; 157 } 158 } 159 return true; 160 } 161 162 bool CallGraph::Dominates(const HloComputation* a, 163 const HloComputation* b) const { 164 tensorflow::gtl::FlatSet<const HloComputation*> visited; 165 return DominatesHelper(a, b, &visited); 166 } 167 168 namespace { 169 170 // Returns the call context of a computation which is called from contexts 'a' 171 // and 'b'. 172 CallContext UnionContexts(CallContext a, CallContext b) { 173 if (a == CallContext::kNone) { 174 return b; 175 } else if (b == CallContext::kNone) { 176 return a; 177 } else if (a == b) { 178 return a; 179 } else { 180 // Contexts are different and neither is kNone, ie one is kSequential and 181 // the other is kParallel. 182 return CallContext::kBoth; 183 } 184 } 185 186 } // namespace 187 188 void CallGraph::SetCallContexts() { 189 std::queue<CallGraphNode*> worklist; 190 191 // Initialize worklist with all roots of the call graph (computations without 192 // callers). 193 for (const HloComputation* computation : module_->computations()) { 194 CallGraphNode& node = GetNode(computation); 195 if (node.callers().empty()) { 196 node.set_context(CallContext::kSequential); 197 worklist.push(&node); 198 } 199 } 200 201 while (!worklist.empty()) { 202 CallGraphNode* node = worklist.front(); 203 worklist.pop(); 204 205 for (const CallSite& callsite : node->callsites()) { 206 for (const HloComputation* callee : callsite.called_computations()) { 207 CallGraphNode& callee_node = GetNode(callee); 208 209 // Update context of callee computation based on the callsite and its 210 // current context. 211 CallContext context_to_add; 212 if (callsite.context() == CallContext::kParallel) { 213 context_to_add = CallContext::kParallel; 214 } else { 215 CHECK_EQ(callsite.context(), CallContext::kSequential); 216 context_to_add = node->context(); 217 } 218 CallContext new_context = 219 UnionContexts(context_to_add, callee_node.context()); 220 221 if (new_context != callee_node.context()) { 222 // Context of computation has been changed so add node to worklist. 223 callee_node.set_context(new_context); 224 worklist.push(&callee_node); 225 } 226 } 227 } 228 } 229 230 // No node should have a kNone calling context. 231 for (const HloComputation* computation : module_->computations()) { 232 CHECK_NE(GetNode(computation).context(), CallContext::kNone); 233 } 234 } 235 236 /* static */ 237 std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) { 238 // Constructor for CallGraph is private so MakeUnique can't be used. 239 auto call_graph = WrapUnique<CallGraph>(new CallGraph(module)); 240 241 VLOG(2) << "Building call graph for:"; 242 XLA_VLOG_LINES(2, module->ToString()); 243 244 // Construct nodes of the call graph and populate the callsites. 245 for (HloComputation* computation : module->computations()) { 246 auto it_added = call_graph->node_indices_.insert( 247 {computation, call_graph->nodes_.size()}); 248 // All computations should be unique, so the computation should not already 249 // exist in the map. 250 CHECK(it_added.second); 251 call_graph->nodes_.emplace_back(computation); 252 253 // Add all callsites in this computation. 254 for (HloInstruction* instruction : computation->instructions()) { 255 call_graph->nodes_.back().AddCallSiteForInstruction(instruction); 256 } 257 } 258 259 // Add caller callsites to each node. 260 for (const HloComputation* computation : module->computations()) { 261 for (const CallSite& callsite : 262 call_graph->GetNode(computation).callsites()) { 263 for (auto* callee : callsite.called_computations()) { 264 // Add caller callsites. 265 call_graph->GetNode(callee).AddCallerCallSite(callsite); 266 } 267 } 268 } 269 270 call_graph->SetCallContexts(); 271 XLA_VLOG_LINES(1, call_graph->ToString()); 272 273 return call_graph; 274 } 275 276 Status CallGraph::VisitNodesInternal( 277 const VisitorFunction& visitor_func, const CallGraphNode& node, 278 tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const { 279 auto pair = visited->insert(&node); 280 if (!pair.second) { 281 // Node was not inserted. Node has already been visited. 282 return Status::OK(); 283 } 284 285 for (const HloComputation* computation : node.callees()) { 286 TF_RETURN_IF_ERROR( 287 VisitNodesInternal(visitor_func, GetNode(computation), visited)); 288 } 289 290 return visitor_func(node); 291 } 292 293 Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, 294 bool visit_unreachable_nodes) const { 295 tensorflow::gtl::FlatSet<const CallGraphNode*> visited; 296 if (visit_unreachable_nodes) { 297 // Traverse from all roots in the call graph. 298 for (const CallGraphNode& node : nodes()) { 299 if (node.callers().empty()) { 300 TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited)); 301 } 302 } 303 } else { 304 // Traverse only from the entry computation. 305 TF_RETURN_IF_ERROR(VisitNodesInternal( 306 visitor_func, GetNode(module_->entry_computation()), &visited)); 307 } 308 309 return Status::OK(); 310 } 311 312 bool CallGraph::IsFlattened() const { 313 for (const CallGraphNode& node : nodes_) { 314 if (node.context() == CallContext::kBoth) { 315 return false; 316 } 317 if (node.context() == CallContext::kSequential && 318 node.caller_callsites().size() > 1) { 319 return false; 320 } 321 } 322 return true; 323 } 324 325 std::pair<HloInstruction*, HloInstruction*> 326 CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, 327 HloInstruction* b) const { 328 // Lambda which returns the next instruction in the callee->caller chain in 329 // the call graph. This is the unique instruction which calls the computation 330 // containing 'instruction'. If more than one instruction calls the 331 // computation containing 'instruction' or no instructions call the 332 // computation then nullptr is returned. 333 auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* { 334 const CallGraphNode& node = GetNode(instruction->parent()); 335 if (node.caller_callsites().size() != 1) { 336 return nullptr; 337 } 338 return node.caller_callsites()[0].instruction(); 339 }; 340 341 // Iterate through the callee->caller chains and find the earliest common 342 // element. 343 for (HloInstruction* a_ancestor = a; a_ancestor != nullptr; 344 a_ancestor = next_caller(a_ancestor)) { 345 for (HloInstruction* b_ancestor = b; b_ancestor != nullptr; 346 b_ancestor = next_caller(b_ancestor)) { 347 if (a_ancestor->parent() == b_ancestor->parent()) { 348 return {a_ancestor, b_ancestor}; 349 } 350 } 351 } 352 return {nullptr, nullptr}; 353 } 354 355 string CallGraph::ToString() const { 356 string out; 357 Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); 358 for (const CallGraphNode& node : nodes()) { 359 Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); 360 Appendf(&out, " calls:\n"); 361 for (const HloComputation* callee : node.callees()) { 362 Appendf(&out, " %s\n", callee->name().c_str()); 363 } 364 Appendf(&out, " called by:\n"); 365 for (const HloComputation* caller : node.callers()) { 366 Appendf(&out, " %s\n", caller->name().c_str()); 367 } 368 Appendf(&out, " callsites:\n"); 369 for (const CallSite& callsite : node.callsites()) { 370 Appendf(&out, " %s\n", callsite.ToString().c_str()); 371 } 372 } 373 return out; 374 } 375 376 } // namespace xla 377