/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

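// Returns the string name of the given call context.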
string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

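// Maps an instruction opcode to the context (sequential, parallel, or none)
// in which it calls its computations.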
CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kAllReduce:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

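// Returns the callsite associated with 'instruction', or nullptr if the
// instruction calls no computations.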
const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

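// Records 'caller_callsite' as a callsite which calls this node's
// computation, and adds the calling computation to the callers if not
// already present.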
void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

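// If 'instruction' calls any computations, adds a callsite for it and records
// each newly seen called computation as a callee.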
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by this
    // instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

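// Recursive helper for Dominates(): returns true if every caller chain from
// 'b' up to a root of the call graph passes through 'a'. 'visited' ensures
// each computation is examined at most once.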
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited node
    // we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e., one is kSequential
    // and the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

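// Sets the call context of each node by propagating contexts from the roots
// of the call graph (computations with no callers) through every callsite
// with a worklist algorithm.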
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

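// Sets the depth of each node to the length of the longest call chain from a
// root computation to that node's computation.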
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

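// A minimal usage sketch (illustrative only; assumes 'module' points to a
// valid HloModule):
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());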
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so absl::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(2) << "Building call graph for:";
  XLA_VLOG_LINES(2, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(1, call_graph->ToString());

  return call_graph;
}

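// Recursive post-order DFS helper for VisitNodes(): visits all callees of
// 'node' before invoking 'visitor_func' on 'node' itself. Nodes already in
// 'visited' are skipped.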
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

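// Visits nodes in post order (callees before callers), starting from the
// entry computation, or from all call graph roots if
// 'visit_unreachable_nodes' is true.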
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

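// The call graph is flattened if no computation is called in both a
// sequential and a parallel context, and no computation called in a
// sequential context has more than one caller callsite.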
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

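// Returns the instructions which call computation 'c', one per caller
// callsite.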
std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

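// Walks 'a' and 'b' up their caller chains until a pair of ancestor
// instructions in the same computation is found. Returns {nullptr, nullptr}
// if no such pair exists, e.g. because a computation on the chain has zero or
// multiple callers.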
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction' or no instructions call the
  // computation then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (or b_ancestor) up the call chain until the call depths
  // of a_ancestor and b_ancestor are equal. Each call to next_caller
  // necessarily reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla