/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Call graph for an HLO module.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_

#include <ostream>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

// The context in which a computation is called by another computation.
enum class CallContext {
  // In a parallel context the computation is applied to each element of the
  // array argument(s). kMap and kReduce instructions call computations in
  // parallel context.
  kParallel,

  // In a sequential context the computation is applied to the entire argument
  // shape(s). kCall and kWhile (body and condition) call computations in
  // sequential context.
  kSequential,

  // A computation is called from both a parallel and sequential context.
  kBoth,

  // During call graph construction kNone is used to indicate that the context
  // has not been determined. This is the top value for the context
  // lattice. After construction, no call sites or call graph nodes should have
  // this value.
  kNone
};

string CallContextToString(CallContext context);
std::ostream& operator<<(std::ostream& out, const CallContext& context);

// Returns the context (parallel or sequential) in which the given instruction
// calls its computations.
CallContext GetInstructionCallContext(const HloInstruction* instruction);
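
// Example (illustrative): the context is determined by the instruction's
// opcode. For hypothetical instructions 'map_instruction' (a kMap) and
// 'while_instruction' (a kWhile):
//   CallContext c = GetInstructionCallContext(map_instruction);
//   // c == CallContext::kParallel
//   c = GetInstructionCallContext(while_instruction);
//   // c == CallContext::kSequential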

// Represents an HLO instruction which calls one or more computations.
class CallSite {
 public:
  CallSite(HloInstruction* instruction,
           const std::vector<HloComputation*>& called_computations,
           CallContext context)
      : instruction_(CHECK_NOTNULL(instruction)),
        called_computations_(called_computations),
        context_(context) {}

  // Returns the instruction associated with this call site.
  HloInstruction* instruction() const { return instruction_; }

  // Returns the computations called at this call site.
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Returns the context in which computations are called at this call site.
  CallContext context() const { return context_; }

  string ToString() const;

 private:
  // The calling instruction.
  HloInstruction* instruction_;

  // The computations called by this callsite.
  const std::vector<HloComputation*> called_computations_;

  // The context in which the computations are called.
  const CallContext context_;
};

// A node in the call graph representing an HLO computation.
class CallGraphNode {
 public:
  CallGraphNode(HloComputation* computation);

  // Returns the computation represented by this call graph node.
  HloComputation* computation() const { return computation_; }

  // Returns the call sites in this computation. These are the instructions in
  // this computation which call other computations.
  const std::vector<CallSite>& callsites() const { return callsites_; }

  // Returns the callsite associated with the given instruction. If this
  // instruction calls no computations, nullptr is returned.
  // Prerequisite: 'instruction' is in the computation associated with this
  // call graph node.
  const CallSite* GetCallSite(const HloInstruction* instruction) const;

  // Returns the computations called by this computation.
  const std::vector<HloComputation*>& callees() const { return callees_; }

  // Returns the call sites in other computations which call this computation.
  const std::vector<CallSite>& caller_callsites() const {
    return caller_callsites_;
  }

  // Returns the computations which call this computation.
  const std::vector<HloComputation*>& callers() const { return callers_; }

  // Returns the context in which this computation is called.
  CallContext context() const { return context_; }

  string ToString() const;

 private:
  // Only CallGraph can modify CallGraphNode.
  friend class CallGraph;

  // Sets the context in which this computation is called.
  void set_context(CallContext value) { context_ = value; }

  // Adds a callsite which calls this computation. Updates callers to include
  // the calling computation.
  void AddCallerCallSite(const CallSite& caller_callsite);

  // If 'instruction' calls any computations, adds a call site for the
  // instruction to this call graph node. If the instruction calls no
  // computations then no call site is added.
  void AddCallSiteForInstruction(HloInstruction* instruction);

  // Computation represented by this call graph node.
  HloComputation* computation_;

  // The computations called by this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callees_;
  tensorflow::gtl::FlatSet<HloComputation*> callee_set_;

  // The computations which call this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callers_;
  tensorflow::gtl::FlatSet<HloComputation*> caller_set_;

  // The call sites in this computation.
  std::vector<CallSite> callsites_;

  // The map from instruction to index in callsites_ for looking up the
  // callsite (if any) associated with a particular instruction in this
  // computation.
  tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_;

  // The call sites in other computations which call this computation.
  std::vector<CallSite> caller_callsites_;

  // The context in which this computation is called.
  CallContext context_ = CallContext::kNone;
};
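
// Example (illustrative): inspecting a node's callers. 'call_graph' and
// 'computation' are assumed to be an existing CallGraph and HloComputation:
//   const CallGraphNode& node = call_graph->GetNode(computation);
//   for (const CallSite& callsite : node.caller_callsites()) {
//     LOG(INFO) << "Called from: " << callsite.ToString();
//   }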

// The call graph for an HLO module. The graph includes a node for each
// computation in the module.
class CallGraph {
 public:
  using VisitorFunction = std::function<Status(const CallGraphNode&)>;

  // Builds and returns a call graph for the given HLO module.
  static std::unique_ptr<CallGraph> Build(const HloModule* module);
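
  // Example (illustrative): building a call graph for a module. 'module' is
  // assumed to be an existing HloModule*:
  //   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  //   const CallGraphNode& entry_node =
  //       call_graph->GetNode(module->entry_computation());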

  // Returns the node associated with the given computation.
  const CallGraphNode& GetNode(const HloComputation* computation) const;
  CallGraphNode& GetNode(const HloComputation* computation);

  // Returns the vector of all nodes in the call graph.
  const std::vector<CallGraphNode>& nodes() const { return nodes_; }

  // Calls the given function on each node in the call graph. Nodes are visited
  // in post order (callees before callers). If visit_unreachable_nodes is true
  // then all nodes in the call graph are visited. Otherwise only those nodes
  // reachable from the entry computation are visited.
  Status VisitNodes(const VisitorFunction& visitor_func,
                    bool visit_unreachable_nodes = true) const;
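
  // Example (illustrative): counting nodes reachable from the entry
  // computation with a visitor lambda. TF_RETURN_IF_ERROR is TensorFlow's
  // status-propagation macro:
  //   int64 num_reachable = 0;
  //   TF_RETURN_IF_ERROR(call_graph->VisitNodes(
  //       [&num_reachable](const CallGraphNode& node) {
  //         ++num_reachable;
  //         return Status::OK();
  //       },
  //       /*visit_unreachable_nodes=*/false));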

  // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
  // dominates computation 'b' iff all callgraph paths in the caller-to-callee
  // direction from a root computation to 'b' pass through computation
  // 'a'. Trivially, a computation dominates itself.
  bool Dominates(const HloComputation* a, const HloComputation* b) const;
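
  // Example (illustrative): domination is reflexive, so for any computation
  // in the module:
  //   CHECK(call_graph->Dominates(computation, computation));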

  // Returns whether 'instruction' is contained in 'computation' either directly
  // ('instruction->parent' is 'computation') or indirectly ('computation'
  // dominates 'instruction->parent' in the call graph).
  bool InstructionIsNestedIn(const HloInstruction* instruction,
                             const HloComputation* computation) const {
    return Dominates(computation, instruction->parent());
  }

  // Returns the nearest call graph ancestors of instructions 'a' and 'b' for
  // which the ancestors are in the same computation. An instruction is a call
  // graph ancestor of 'a' if the instruction calls the computation containing
  // 'a' either directly or transitively. Degenerately, an instruction is an
  // ancestor of itself. A pair of nullptrs is returned if there is no common
  // ancestor or if the caller chain of 'a' or 'b' diverges (has multiple
  // callers) before the nearest common ancestor.
  //
  // Example:
  //
  // Entry computation:
  //   %x = Call(A, {Constant(42.0)})
  //   %y = Call(B, {%x})
  //
  // Computation A:
  //   %a = Negate(Param())
  //
  // Computation B:
  //   %b = Exp(Param())
  //
  // If called with %a and %b, this function would return (%x, %y). %x is an
  // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same
  // computation.
  std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation(
      HloInstruction* a, HloInstruction* b) const;

  // Returns whether the call graph is flattened. A call graph is flattened if
  // every computation called in a sequential context (e.g., kWhile or kCall)
  // has zero or one callsite, and no computation is called from both a
  // parallel and a sequential context. The call graph of a module can be
  // flattened with FlattenCallGraph.
  bool IsFlattened() const;

  string ToString() const;

 private:
  CallGraph(const HloModule* module);

  // Sets the call contexts for every node in the graph.
  void SetCallContexts();

  // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS
  // post order (callee before caller) calling visitor_func on each node. Adds
  // nodes to 'visited' as each node is visited. Skips nodes already in
  // 'visited'.
  Status VisitNodesInternal(
      const VisitorFunction& visitor_func, const CallGraphNode& node,
      tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;

  // Recursive helper for computing whether 'a' dominates 'b' in the call
  // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
  // and 'visited' is the set of computations which have been visited.
  bool DominatesHelper(
      const HloComputation* a, const HloComputation* b_ancestor,
      tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;

  // The HLO module represented by this call graph.
  const HloModule* module_ = nullptr;

  // Vector of all nodes in the call graph.
  std::vector<CallGraphNode> nodes_;

  // Map from HLO computation to the index of the corresponding call graph node
  // in nodes_.
  tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_