/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// An HloComputation contains a directed acyclic graph of HLO instructions. The
// computation has a single root instruction which produces the output of the
// computation.
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    // Takes ownership of the given instruction, records it as the most
    // recently added instruction, and returns a raw pointer to it for use as
    // an operand of later instructions.
    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    // Calls `func` on each instruction added so far, in insertion order.
    // Stops early and returns the first non-OK status produced by `func`.
    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Add new parameter instruction to the computation.
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Remove an instruction from the computation and also transitively any
  // operand that has no users post removing an instruction. The instruction
  // must have no users. Instruction is deallocated with this call.
  Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation and have the same shape as
  // the result of the computation for non fusion computations.
  void set_root_instruction(HloInstruction* new_root_instruction);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  // CHECK-fails if param_no is out of range.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   module: the module which will contain the computation. The newly created
  //     computation is *not* added to the module, however.
  //   proto: the proto to convert from.
  //   computation_map: a map from computation name to HloComputation*. This
  //     map must contain all computations which the newly constructed
  //     computation calls.
  //   add_fused_computation: A function to call to add a fused
  //     computation. Used only when the instruction is a fusion instruction.
  //   fusion_instruction: if non-null then the newly created computation will
  //     be constructed as a fused computation with this instruction as its
  //     fusion parent.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      HloModule* module, const HloComputationProto& proto,
      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
      const std::function<void(std::unique_ptr<HloComputation>)>&
          add_fused_computation,
      HloInstruction* fusion_instruction = nullptr);

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::list<HloInstruction*> MakeInstructionPostOrder() const;

  // Computes and returns the reachability between HLO instructions in the
  // computation. The returned HloReachabilityMap is constructed such that
  // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
  // directed path (from producer to consumer) from 'a' to 'b'. Both data
  // dependencies (operands) and control dependencies are considered for
  // reachability. Trivially an instruction is reachable from itself.
  std::unique_ptr<HloReachabilityMap> ComputeReachability() const;

  // Updates the given reachability map after the immediate predecessor set
  // (operands and control predecessors) of 'instruction' has changed.
  void UpdateReachabilityThroughInstruction(
      const HloInstruction* instruction, HloReachabilityMap* reachability_map);

  int64 instruction_count() const { return instructions_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (eg, via a map instruction) then A will
  // appear after B in the list.
  std::list<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or
  // fusion into a library call. Instructions must be in reverse topological
  // order (root of the fused expression first). Replaces all uses of the
  // original root instruction with the fusion instruction. The original
  // instructions are removed if they have no uses after fusion (this is
  // necessarily true for at least the root).
  HloInstruction* CreateFusionInstruction(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result without layout).
  ProgramShape ComputeProgramShape() const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const;

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction. Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have the compatible shapes.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       const std::vector<const HloInstruction*>& order) const;

  // Same as Accept() above, but the visitor is given as a function.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;

  // Returns a deep copy of this computation including all instructions.
  // If the module pointer is not nullptr, it will be the module where
  // the cloned computations will be added to (in order to support deep
  // cloning).
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloModule* module = nullptr);

  // Like Clone(), but if an instruction is present in replacement_map, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
          replacements,
      HloModule* module = nullptr, const string& suffix = "clone");

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation with the exception of fusion computation. A parameter
  // instruction is removable for a fusion computation.
  //
  // Note that IsRemovable() is a necessary condition to remove an instruction
  // rather than a sufficient condition. For example, instructions with
  // side-effect (e.g., Send, Infeed) may be removed from a computation, but
  // the transformation must guarantee the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsRemovable(const HloInstruction* instruction);

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns true if this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Helper for setting the parent of instructions that are added to this
  // computation.
  void Reparent(HloInstruction* instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
      ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  string name_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction. Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  std::unordered_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_