/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::StrCat;

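// Returns a human-readable name for the given CallContext value.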
string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

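// Returns the context in which 'instruction' invokes its called computations:
// control-flow instructions (call, conditional, while) invoke them
// sequentially, while mapping instructions (map, reduce, reduce-window,
// select-and-scatter, fusion) invoke them in a parallel context. Instructions
// which call no computations return kNone.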
CallContext GetInstructionCallContext(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kFusion:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(instruction()->name(), " calls in context ",
                CallContextToString(context()), ": ",
                tensorflow::str_util::Join(
                    called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

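// Returns the callsite associated with 'instruction' in this node, or nullptr
// if the instruction does not call any computations.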
const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

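// Records 'caller_callsite' as a caller of this node's computation and adds
// the calling computation to the caller list if it has not been seen before.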
void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

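// If 'instruction' calls any computations, creates a callsite for it and adds
// any called computations not already known to this node's callee list.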
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction);
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

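// Returns the call graph node for 'computation'. The computation must already
// have a node in the graph (enforced with a CHECK).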
const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

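// Recursive helper for Dominates. Returns true if every caller chain leading
// from a root computation down to 'b' passes through 'a'. 'visited' caches
// computations already shown to be dominated so shared callers are not
// re-walked.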
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic, so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  tensorflow::gtl::FlatSet<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e. one is kSequential
    // and the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

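// Computes the call context of every node with a breadth-first worklist
// traversal. Root computations start in the sequential context; a parallel
// callsite forces the parallel context onto its callees, and a computation
// reached in both contexts ends up with kBoth.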
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // The context of the computation has changed, so add the node to
          // the worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

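// Builds the call graph in two passes over the module: the first pass creates
// a node per computation and records the callsites within it, the second pass
// attaches caller callsites to each callee node. Call contexts are then
// propagated over the finished graph.
//
// Example usage (sketch; 'module' is an HloModule* provided by the caller):
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());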
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private, so MakeUnique can't be used.
  auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(2) << "Building call graph for:";
  XLA_VLOG_LINES(2, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  XLA_VLOG_LINES(1, call_graph->ToString());

  return call_graph;
}

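// Recursive helper for VisitNodes. Performs a post-order traversal: the
// callees of 'node' are visited before 'node' itself, and 'visited' prevents
// revisiting nodes reachable through more than one caller.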
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted, which means it has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

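// Applies 'visitor_func' to nodes in post-order (callees before callers). If
// 'visit_unreachable_nodes' is true the traversal starts from every root of
// the call graph; otherwise only nodes reachable from the entry computation
// are visited.
//
// Example usage (sketch; assumes the visitor lambda is convertible to
// VisitorFunction):
//   TF_RETURN_IF_ERROR(call_graph->VisitNodes(
//       [](const CallGraphNode& node) {
//         VLOG(1) << node.computation()->name();
//         return Status::OK();
//       },
//       /*visit_unreachable_nodes=*/true));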
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

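// The call graph is flattened when no computation is called from both a
// sequential and a parallel context and no sequentially called computation
// has more than one caller callsite.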
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

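// Walks up the (unique) caller chains of 'a' and 'b' and returns the first
// pair of ancestor instructions which live in the same computation. Returns
// {nullptr, nullptr} if no such pair exists, for example when a computation
// on one of the chains has more than one caller.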
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain of
  // the call graph: the unique instruction which calls the computation
  // containing 'instruction'. Returns nullptr if the computation containing
  // 'instruction' has zero callers or more than one caller.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  for (HloInstruction* a_ancestor = a; a_ancestor != nullptr;
       a_ancestor = next_caller(a_ancestor)) {
    for (HloInstruction* b_ancestor = b; b_ancestor != nullptr;
         b_ancestor = next_caller(b_ancestor)) {
      if (a_ancestor->parent() == b_ancestor->parent()) {
        return {a_ancestor, b_ancestor};
      }
    }
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  Appendf(&out, "Call graph for module %s:\n", module_->name().c_str());
  for (const CallGraphNode& node : nodes()) {
    Appendf(&out, "Computation %s:\n", node.computation()->name().c_str());
    Appendf(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      Appendf(&out, "    %s\n", callee->name().c_str());
    }
    Appendf(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      Appendf(&out, "    %s\n", caller->name().c_str());
    }
    Appendf(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      Appendf(&out, "    %s\n", callsite.ToString().c_str());
    }
  }
  return out;
}

}  // namespace xla