Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Call graph for an HLO module.
     17 
     18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
     19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
     20 
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
     28 
     29 namespace xla {
     30 
     31 // The context in which a computation is called by another computation.
     32 enum class CallContext {
     33   // In a parallel context the computation is applied to each element of the
     34   // array argument(s). kMap and kReduce instructions call computations in
     35   // parallel context.
     36   kParallel,
     37 
     38   // In a sequential context the computation is applied to the entire argument
     39   // shape(s). kCall and kWhile (body and condition) call computations in
     40   // sequential context.
     41   kSequential,
     42 
     43   // A computation is called from both a parallel and sequential context.
     44   kBoth,
     45 
     46   // During call graph construction kNone is used to indicate that the context
     47   // has not been determined. This is the top value for the context
     48   // lattice. After construction, no call sites or call graph nodes should have
     49   // this value.
     50   kNone
     51 };
     52 
// Returns a human-readable string for the given call context.
string CallContextToString(CallContext context);
// Streams the human-readable form of `context` into `out`.
std::ostream& operator<<(std::ostream& out, const CallContext& context);

// Returns the context (parallel or sequential) in which an instruction with
// the given opcode invokes its called computations.
CallContext GetInstructionCallContext(HloOpcode opcode);
     57 
     58 // Represents an HLO instruction which calls one or more computations.
     59 class CallSite {
     60  public:
     61   CallSite(HloInstruction* instruction,
     62            const std::vector<HloComputation*>& called_computations,
     63            CallContext context)
     64       : instruction_(CHECK_NOTNULL(instruction)),
     65         called_computations_(called_computations),
     66         context_(context) {}
     67 
     68   // Returns the instruction associated with this call site.
     69   HloInstruction* instruction() const { return instruction_; }
     70 
     71   // Returns the computations called at this call site.
     72   const std::vector<HloComputation*>& called_computations() const {
     73     return called_computations_;
     74   }
     75 
     76   // Returns the context in which computations are called at this call site.
     77   CallContext context() const { return context_; }
     78 
     79   string ToString() const;
     80 
     81  private:
     82   // The calling instruction.
     83   HloInstruction* instruction_;
     84 
     85   // The computations called by this callsite.
     86   const std::vector<HloComputation*> called_computations_;
     87 
     88   // The context in which the computations are called.
     89   const CallContext context_;
     90 };
     91 
     92 // A node in the call graph representing an HLO computation.
     93 class CallGraphNode {
     94  public:
     95   CallGraphNode(HloComputation* computation);
     96 
     97   // Returns the computation represented by this call graph node.
     98   HloComputation* computation() const { return computation_; }
     99 
    100   // Returns the call sites in this computation. These are the instructions in
    101   // this computation which call other computations.
    102   const std::vector<CallSite>& callsites() const { return callsites_; }
    103 
    104   // Returns the callsite associated with the given instruction. If this
    105   // instruction calls no computations nullptr is returned.
    106   // Prerequisite: instruction is in the computation associated with this call
    107   // graph node.
    108   const CallSite* GetCallSite(const HloInstruction* instruction) const;
    109 
    110   // Returns the computations called by this computation.
    111   const std::vector<HloComputation*>& callees() const { return callees_; }
    112 
    113   // Returns the call sites in other computations which call this computation.
    114   const std::vector<CallSite>& caller_callsites() const {
    115     return caller_callsites_;
    116   }
    117 
    118   // Returns the computations which call this computation.
    119   const std::vector<HloComputation*>& callers() const { return callers_; }
    120 
    121   // Returns the context in which this computation is called.
    122   CallContext context() const { return context_; }
    123 
    124   // Returns the depth of this node in the call graph. The depth is defined as
    125   // the length of the longest call chain from a computation with no callers
    126   // (usually the entry computation node) to this node.
    127   int depth() const { return depth_; }
    128 
    129   string ToString() const;
    130 
    131  private:
    132   // Only CallGraph can modify CallGraphNode.
    133   friend class CallGraph;
    134 
    135   // Sets the context in which this computation is called.
    136   void set_context(CallContext value) { context_ = value; }
    137 
    138   // Sets the depth of this node in the graph.
    139   void set_depth(int value) { depth_ = value; }
    140 
    141   // Adds a callsite which calls this computation. Updates callers to include
    142   // the calling computation.
    143   void AddCallerCallSite(const CallSite& caller_callsite);
    144 
    145   // If instruction calls any computations adds a call site for this instruction
    146   // to the call graph node. If the instruction calls no computations then no
    147   // call site is added.
    148   void AddCallSiteForInstruction(HloInstruction* instruction);
    149 
    150   // Computation represented by this call graph node.
    151   HloComputation* computation_;
    152 
    153   // The computations called by this computation. The vector is used for a
    154   // stable ordering and the set enables fast membership testing.
    155   std::vector<HloComputation*> callees_;
    156   absl::flat_hash_set<HloComputation*> callee_set_;
    157 
    158   // The computations which call this computation. The vector is used for a
    159   // stable ordering and the set enables fast membership testing.
    160   std::vector<HloComputation*> callers_;
    161   absl::flat_hash_set<HloComputation*> caller_set_;
    162 
    163   // The call sites in this computation
    164   std::vector<CallSite> callsites_;
    165 
    166   // The map from instruction to index in callsites_ for looking up the callsite
    167   // (if any) associated with a particular instruction in this computation.
    168   absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;
    169 
    170   // The call sites in other computations which call this computation.
    171   std::vector<CallSite> caller_callsites_;
    172 
    173   // The context in which this computation is called.
    174   CallContext context_ = CallContext::kNone;
    175 
    176   // The depth of this node in the call graph.
    177   int depth_ = 0;
    178 };
    179 
    180 // The call graph for an HLO module. The graph includes a node for each
    181 // computation in the module.
    182 class CallGraph {
    183  public:
    184   using VisitorFunction = std::function<Status(const CallGraphNode&)>;
    185 
    186   // Builds and returns a call graph for the given HLO module.
    187   static std::unique_ptr<CallGraph> Build(const HloModule* module);
    188 
    189   // Returns the node associated with the given computation.
    190   const CallGraphNode& GetNode(const HloComputation* computation) const;
    191   CallGraphNode& GetNode(const HloComputation* computation);
    192 
    193   // Returns the vector of all nodes in the call graph.
    194   const std::vector<CallGraphNode>& nodes() const { return nodes_; }
    195 
    196   // Calls the given function on each node in the call graph. Nodes are visited
    197   // in post order (callees before callers). If visit_unreachable_nodes is true
    198   // then all nodes in the call graph are visited. Otherwise only those nodes
    199   // reachable from the entry computation are visited.
    200   Status VisitNodes(const VisitorFunction& visitor_func,
    201                     bool visit_unreachable_nodes = true) const;
    202 
    203   // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
    204   // dominates computation 'b' iff all callgraph paths in the caller-to-callee
    205   // direction from a root computation to 'b' pass through computation
    206   // 'a'. Trivially, a computation dominates itself.
    207   bool Dominates(const HloComputation* a, const HloComputation* b) const;
    208 
    209   // Returns whether 'instruction' is contained in 'computation' either directly
    210   // ('instruction->parent' is 'computation') or indirectly ('computation'
    211   // dominates 'instruction->parent' in the call graph).
    212   bool InstructionIsNestedIn(const HloInstruction* instruction,
    213                              const HloComputation* computation) const {
    214     return Dominates(computation, instruction->parent());
    215   }
    216 
    217   // Returns the nearest call graph ancestors of instructions 'a' and 'b' for
    218   // which the ancestors are in the same computation. An instruction is an call
    219   // graph ancestor of 'a' if the instruction calls the computation containing
    220   // 'a' either directly or transitively. Degeneratively an instruction is an
    221   // ancestor of itself. nullptr is returned if there is no common ancestor or
    222   // if the caller chain of 'a' or 'b' diverges (has multiple callers) before
    223   // the nearest common ancestor.
    224   //
    225   // Example:
    226   //
    227   // Entry computation:
    228   //   %x = Call(A, {Constant(42.0)})
    229   //   %y = Call(B, {%x})
    230   //
    231   // Computation A:
    232   //   %a = Negate(Param())
    233   //
    234   // Computation B:
    235   //   %b = Exp(Param());
    236   //
    237   // If called with %a and %b, this function would return (%x, %y). %x is an
    238   // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same
    239   // computation.
    240   std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation(
    241       HloInstruction* a, HloInstruction* b) const;
    242 
    243   // Returns whether the call graph is flattened. A call graph is flattened if
    244   // every computation called in a sequential context (eg, kWhile or kCall) has
    245   // zero or one callsite, and no computation is called from both a parallel and
    246   // sequential context. The call graph of a module can be flattened with
    247   // FlattenCallGraph.
    248   bool IsFlattened() const;
    249 
    250   // Returns a vector of instructions calling the passed computation.
    251   // (Often a vector of size 1.)
    252   std::vector<HloInstruction*> GetComputationCallers(HloComputation* c);
    253 
    254   string ToString() const;
    255 
    256  private:
    257   CallGraph(const HloModule* module);
    258 
    259   // Not copyable.
    260   CallGraph(const CallGraph&) = delete;
    261   CallGraph& operator=(const CallGraph&) = delete;
    262 
    263   // Sets the call contexts for every node in the graph.
    264   void SetCallContexts();
    265 
    266   // Sets the call node depths for every node in the graph.
    267   void SetNodeDepths();
    268 
    269   // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS
    270   // post order (callee before caller) calling visitor_func on each node. Adds
    271   // nodes to 'visited' as each node is visited. Skips nodes already in
    272   // 'visited'.
    273   Status VisitNodesInternal(
    274       const VisitorFunction& visitor_func, const CallGraphNode& node,
    275       absl::flat_hash_set<const CallGraphNode*>* visited) const;
    276 
    277   // Recursive helper for computing whether 'a' dominates 'b' in the call
    278   // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
    279   // and 'visited' is the set of computations which have been visited.
    280   bool DominatesHelper(
    281       const HloComputation* a, const HloComputation* b,
    282       absl::flat_hash_set<const HloComputation*>* visited) const;
    283 
    284   // The HLO module represented by this call graph.
    285   const HloModule* module_ = nullptr;
    286 
    287   // Vector of all nodes in the call graph.
    288   std::vector<CallGraphNode> nodes_;
    289 
    290   // Map from HLO computation to the index of the corresponding call graph node
    291   // in nodes_.
    292   absl::flat_hash_map<const HloComputation*, int64> node_indices_;
    293 };
    294 
    295 }  // namespace xla
    296 
    297 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
    298