1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Call graph for an HLO module. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ 20 21 #include <ostream> 22 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/core/lib/gtl/flatmap.h" 27 #include "tensorflow/core/lib/gtl/flatset.h" 28 29 namespace xla { 30 31 // The context in which a computation is called by another computation. 32 enum class CallContext { 33 // In a parallel contex the computation is applied to each element of the 34 // array argument(s). kMap and kReduce instructions call computations in 35 // parallel context. 36 kParallel, 37 38 // In a sequential context the computation is applied to the entire argument 39 // shape(s). kCall and kWhile (body and condition) call computations in 40 // sequential context. 41 kSequential, 42 43 // A computation is called from both a parallel and sequential context. 44 kBoth, 45 46 // During call graph construction kNone is used to indicate that the context 47 // has not been determined. This is the top value for the context 48 // lattice. After construction, no call sites or call graph nodes should have 49 // this value. 50 kNone 51 }; 52 53 string CallContextToString(CallContext context); 54 std::ostream& operator<<(std::ostream& out, const CallContext& context); 55 56 CallContext GetInstructionCallContext(const HloInstruction* instruction); 57 58 // Represents an HLO instruction which calls one or more computations. 59 class CallSite { 60 public: 61 CallSite(HloInstruction* instruction, 62 const std::vector<HloComputation*>& called_computations, 63 CallContext context) 64 : instruction_(CHECK_NOTNULL(instruction)), 65 called_computations_(called_computations), 66 context_(context) {} 67 68 // Returns the instruction associated with this call site. 69 HloInstruction* instruction() const { return instruction_; } 70 71 // Returns the computations called at this call site. 72 const std::vector<HloComputation*>& called_computations() const { 73 return called_computations_; 74 } 75 76 // Returns the context in which computations are called at this call site. 77 CallContext context() const { return context_; } 78 79 string ToString() const; 80 81 private: 82 // The calling instruction. 83 HloInstruction* instruction_; 84 85 // The computations called by this callsite. 86 const std::vector<HloComputation*> called_computations_; 87 88 // The context in which the computations are called. 89 const CallContext context_; 90 }; 91 92 // A node in the call graph representing an HLO computation. 93 class CallGraphNode { 94 public: 95 CallGraphNode(HloComputation* computation); 96 97 // Returns the computation represented by this call graph node. 98 HloComputation* computation() const { return computation_; } 99 100 // Returns the call sites in this computation. These are the instructions in 101 // this computation which call other computations. 102 const std::vector<CallSite>& callsites() const { return callsites_; } 103 104 // Returns the callsite associated with the given instruction. If this 105 // instruction calls no computations nullptr is returned. 106 // Prerequisite: instruction is in the computation associated with this call 107 // graph node. 108 const CallSite* GetCallSite(const HloInstruction* instruction) const; 109 110 // Returns the computations called by this computation. 111 const std::vector<HloComputation*>& callees() const { return callees_; } 112 113 // Returns the call sites in other computations which call this computation. 114 const std::vector<CallSite>& caller_callsites() const { 115 return caller_callsites_; 116 } 117 118 // Returns the computations which call this computation. 119 const std::vector<HloComputation*>& callers() const { return callers_; } 120 121 // Returns the context in which this computation is called. 122 CallContext context() const { return context_; } 123 124 string ToString() const; 125 126 private: 127 // Only CallGraph can modify CallGraphNode. 128 friend class CallGraph; 129 130 // Sets the context in which this computation is called. 131 void set_context(CallContext value) { context_ = value; } 132 133 // Adds a callsite which calls this computation. Updates callers to include 134 // the calling computation. 135 void AddCallerCallSite(const CallSite& caller_callsite); 136 137 // If instruction calls any computations adds a call site for this instruction 138 // to the call graph node. If the instruction calls no computations then no 139 // call site is added. 140 void AddCallSiteForInstruction(HloInstruction* instruction); 141 142 // Computation represented by this call graph node. 143 HloComputation* computation_; 144 145 // The computations called by this computation. The vector is used for a 146 // stable ordering and the set enables fast membership testing. 147 std::vector<HloComputation*> callees_; 148 tensorflow::gtl::FlatSet<HloComputation*> callee_set_; 149 150 // The computations which call this computation. The vector is used for a 151 // stable ordering and the set enables fast membership testing. 152 std::vector<HloComputation*> callers_; 153 tensorflow::gtl::FlatSet<HloComputation*> caller_set_; 154 155 // The call sites in this computation 156 std::vector<CallSite> callsites_; 157 158 // The map from instruction to index in callsites_ for looking up the callsite 159 // (if any) associated with a particular instruction in this computation. 160 tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_; 161 162 // The call sites in other computations which call this computation. 163 std::vector<CallSite> caller_callsites_; 164 165 // The context in which this computation is called. 166 CallContext context_ = CallContext::kNone; 167 }; 168 169 // The call graph for an HLO module. The graph includes a node for each 170 // computation in the module. 171 class CallGraph { 172 public: 173 using VisitorFunction = std::function<Status(const CallGraphNode&)>; 174 175 // Builds and returns a call graph for the given HLO module. 176 static std::unique_ptr<CallGraph> Build(const HloModule* module); 177 178 // Returns the node associated with the given computation. 179 const CallGraphNode& GetNode(const HloComputation* computation) const; 180 CallGraphNode& GetNode(const HloComputation* computation); 181 182 // Returns the vector of all nodes in the call graph. 183 const std::vector<CallGraphNode>& nodes() const { return nodes_; } 184 185 // Calls the given function on each node in the call graph. Nodes are visited 186 // in post order (callees before callers). If visit_unreachable_nodes is true 187 // then all nodes in the call graph are visited. Otherwise only those nodes 188 // reachable from the entry computation are visited. 189 Status VisitNodes(const VisitorFunction& visitor_func, 190 bool visit_unreachable_nodes = true) const; 191 192 // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' 193 // dominates computation 'b' iff all callgraph paths in the caller-to-callee 194 // direction from a root computation to 'b' pass through computation 195 // 'a'. Trivially, a computation dominates itself. 196 bool Dominates(const HloComputation* a, const HloComputation* b) const; 197 198 // Returns whether 'instruction' is contained in 'computation' either directly 199 // ('instruction->parent' is 'computation') or indirectly ('computation' 200 // dominates 'instruction->parent' in the call graph). 201 bool InstructionIsNestedIn(const HloInstruction* instruction, 202 const HloComputation* computation) const { 203 return Dominates(computation, instruction->parent()); 204 } 205 206 // Returns the nearest call graph ancestors of instructions 'a' and 'b' for 207 // which the ancestors are in the same computation. An instruction is an call 208 // graph ancestor of 'a' if the instruction calls the computation containing 209 // 'a' either directly or transitively. Degeneratively an instruction is an 210 // ancestor of itself. nullptr is returned if there is no common ancestor or 211 // if the caller chain of 'a' or 'b' diverges (has multiple callers) before 212 // the nearest common ancestor. 213 // 214 // Example: 215 // 216 // Entry computation: 217 // %x = Call(A, {Constant(42.0)}) 218 // %y = Call(B, {%x}) 219 // 220 // Computation A: 221 // %a = Negate(Param()) 222 // 223 // Computation B: 224 // %b = Exp(Param()); 225 // 226 // If called with %a and %b, this function would return (%x, %y). %x is an 227 // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same 228 // computation. 229 std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation( 230 HloInstruction* a, HloInstruction* b) const; 231 232 // Returns whether the call graph is flattened. A call graph is flattened if 233 // every computation called in a sequential context (eg, kWhile or kCall) has 234 // zero or one callsite, and no computation is called from both a parallel and 235 // sequential context. The call graph of a module can be flattened with 236 // FlattenCallGraph. 237 bool IsFlattened() const; 238 239 string ToString() const; 240 241 private: 242 CallGraph(const HloModule* module); 243 244 // Sets the call contexts for every node in the graph. 245 void SetCallContexts(); 246 247 // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS 248 // post order (callee before caller) calling visitor_func on each node. Adds 249 // nodes to 'visited' as each node is visited. Skips nodes already in 250 // 'visited'. 251 Status VisitNodesInternal( 252 const VisitorFunction& visitor_func, const CallGraphNode& node, 253 tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const; 254 255 // Recursive helper for computing whether 'a' dominates 'b' in the call 256 // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), 257 // and 'visited' is the set of computations which have been visited. 258 bool DominatesHelper( 259 const HloComputation* a, const HloComputation* b, 260 tensorflow::gtl::FlatSet<const HloComputation*>* visited) const; 261 262 // The HLO module represented by this call graph. 263 const HloModule* module_ = nullptr; 264 265 // Vector of all nodes in the call graph. 266 std::vector<CallGraphNode> nodes_; 267 268 // Map from HLO computation to the index of the corresponding call graph node 269 // in nodes_. 270 tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_; 271 }; 272 273 } // namespace xla 274 275 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ 276