/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// An HloComputation contains a directed acyclic graph of HLO instructions. The
// computation has a single root instruction which produces the output of the
// computation.
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

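  // Example: a hypothetical sketch of building a small computation with
  // Builder. The shape and instruction names are illustrative only; Shape,
  // ShapeUtil, and the HloInstruction::Create* factories come from the
  // included XLA headers.
  //
  //   HloComputation::Builder builder("add_computation");
  //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, r0f32, "x"));
  //   HloInstruction* y = builder.AddInstruction(
  //       HloInstruction::CreateParameter(1, r0f32, "y"));
  //   builder.AddInstruction(
  //       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, x, y));
  //   // With no explicit root, the last added instruction becomes the root.
  //   std::unique_ptr<HloComputation> computation = builder.Build();
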
  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for a fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Add a new parameter instruction to the computation. The instruction is
  // appended to the parameter list and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. The instruction is deallocated by this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Remove an instruction from the computation and, transitively, any operands
  // that are left without users after the removal. The instruction must have
  // no users. The instruction is deallocated by this call.
  Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation and, for non-fusion
  // computations, must have the same shape as the result of the computation.
  void set_root_instruction(HloInstruction* new_root_instruction);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   module: the module which will contain the computation. The newly created
  //     computation is *not* added to the module, however.
  //   proto: the proto to convert from.
  //   computation_map: a map from computation name to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  //   add_fused_computation: a function to call to add a fused
  //     computation. Used only when the instruction is a fusion instruction.
  //   fusion_instruction: if non-null then the newly created computation will
  //     be constructed as a fused computation with this instruction as its
  //     fusion parent.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      HloModule* module, const HloComputationProto& proto,
      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
      const std::function<void(std::unique_ptr<HloComputation>)>&
          add_fused_computation,
      HloInstruction* fusion_instruction = nullptr);

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::list<HloInstruction*> MakeInstructionPostOrder() const;

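  // Example: a minimal sketch of visiting instructions in post-order, so that
  // every operand is visited before its users. (VLOG is the usual TensorFlow
  // logging macro; the log message is illustrative only.)
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     VLOG(2) << "visiting: " << instr->name();
  //   }
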
  // Computes and returns the reachability between HLO instructions in the
  // computation. The returned HloReachabilityMap is constructed such that
  // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
  // directed path (from producer to consumer) from 'a' to 'b'. Both data
  // dependencies (operands) and control dependencies are considered for
  // reachability. Trivially an instruction is reachable from itself.
  std::unique_ptr<HloReachabilityMap> ComputeReachability() const;

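  // Example: a hypothetical sketch of a reachability query, assuming `a` and
  // `b` are instructions already in this computation.
  //
  //   std::unique_ptr<HloReachabilityMap> reachability =
  //       computation->ComputeReachability();
  //   if (reachability->IsReachable(a, b)) {
  //     // There is a data or control path from `a` to `b`.
  //   }
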
  // Updates the given reachability map after the immediate predecessor set
  // (operands and control predecessors) of 'instruction' has changed.
  void UpdateReachabilityThroughInstruction(
      const HloInstruction* instruction, HloReachabilityMap* reachability_map);

  int64 instruction_count() const { return instructions_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if computation
  // A calls computation B (e.g., via a map instruction) then A will appear
  // after B in the list.
  std::list<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
  // into a library call. Instructions must be in reverse topological order
  // (root of the fused expression first). Replaces all uses of the original
  // root instruction with the fusion instruction. The original instructions are
  // removed if they have no uses after fusion (this is necessarily true for at
  // least the root).
  HloInstruction* CreateFusionInstruction(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

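  // Example: a hypothetical sketch that fuses a multiply and the add feeding
  // it into a single loop-fusion instruction. `mul` and `add` are assumed to
  // be existing instructions in this computation, listed in reverse
  // topological order (fused root first).
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       {mul, add}, HloInstruction::FusionKind::kLoop);
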
  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are added
  // to the computation. For array-shaped values, this method trivially returns
  // a kCopy instruction. For tuple-shaped instructions, the copy is performed
  // with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

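  // Example: a minimal sketch that deep-copies the root instruction; with the
  // default arguments, every array element of a tuple-shaped value is copied.
  //
  //   TF_ASSIGN_OR_RETURN(
  //       HloInstruction* copy,
  //       computation->DeepCopyInstruction(computation->root_instruction()));
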
  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result without layout).
  ProgramShape ComputeProgramShape() const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const;

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction.  Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

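  // Example: a hypothetical sketch that rewrites an existing kNegate
  // instruction `negate` into a kAbs over the same operand, replacing all of
  // its uses.
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       negate,
  //       HloInstruction::CreateUnary(negate->shape(), HloOpcode::kAbs,
  //                                   negate->mutable_operand(0))));
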
  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       const std::vector<const HloInstruction*>& order) const;

  // Same as Accept() above, but the visitor is given as a function.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;

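  // Example: a minimal sketch of the functional Accept() overload, counting
  // constants in post-order.
  //
  //   int64 num_constants = 0;
  //   TF_RETURN_IF_ERROR(computation->Accept([&](HloInstruction* instr) {
  //     if (instr->opcode() == HloOpcode::kConstant) {
  //       ++num_constants;
  //     }
  //     return Status::OK();
  //   }));
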
  // Returns a deep copy of this computation, including all instructions.
  // If the module pointer is not nullptr, the cloned computations are added to
  // that module (in order to support deep cloning).
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloModule* module = nullptr);

  // Like Clone(), but if an instruction is present in `replacements`, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If `replacements` maps a key to nullptr, we remove that instruction from
  // the new computation.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
          replacements,
      HloModule* module = nullptr, const string& suffix = "clone");

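  // Example: a hypothetical sketch that clones this computation while
  // swapping an existing instruction `instr` for a constant zero (the Literal
  // factory shown is an assumption about the surrounding XLA API). Mapping to
  // nullptr would instead drop `instr` from the clone.
  //
  //   std::unordered_map<const HloInstruction*,
  //                      std::unique_ptr<HloInstruction>> replacements;
  //   replacements[instr] =
  //       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f));
  //   std::unique_ptr<HloComputation> clone =
  //       computation->CloneWithReplacements(std::move(replacements));
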
  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation, with the exception of fusion computations: a
  // parameter instruction of a fusion computation is removable.
  //
  // Note that IsRemovable() is a necessary condition for removing an
  // instruction rather than a sufficient one. For example, instructions with
  // side effects (e.g., Send, Infeed) may be removed from a computation, but
  // the transformation must guarantee that the invariants relevant to those
  // instructions still hold (e.g., Send and Recv must be removed together to
  // keep each channel complete).
  bool IsRemovable(const HloInstruction* instruction);

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Helper for setting the parent of instructions that are added to this
  // computation.
  void Reparent(HloInstruction* instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
      ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  string name_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction.  Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  std::unordered_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_