Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. The analysis is
// module-scoped, tracking values across computation boundaries.
     19 
     20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
     21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
     22 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
     39 
     40 namespace xla {
     41 
     42 // Analysis which identifies all HLO values and their uses in an HLO module.
     43 class HloDataflowAnalysis {
     44  public:
     45   // Different backends can have very different ways to do fusion, so we give
     46   // backends the flexibility to decide whether an fusion instruction can share
     47   // buffer with it's operands. If this is not specified, a default strategy
     48   // will be used; if this is specified, it will be applied *in addition* to the
     49   // default strategy.
     50   //
     51   // The first parameter of the function should be the fusion instruction, the
     52   // second parameter should be an operand of the fusion instruction.
     53   //
     54   // TODO(b/80315712): Find a better way to tell whether a fusion can share
     55   // buffer.
     56   using FusionCanShareBufferFunction = std::function<bool(
     57       const HloInstruction* fusion, const HloInstruction* operand)>;
     58 
     59   // Run dataflow analysis on the given module. Parameters:
     60   //
     61   //   ssa_form : If true then new values are defined at the merge points of
     62   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
     63   //     values".  The merge is formed by the init value and loop backedge. The
     64   //     SSA form is minimal in that a new phi value is defined only if the
     65   //     merge point is reachable by multiple different values. The SSA form is
     66   //     also in loop-closed form in that no values defined inside of a loop
     67   //     (while body) is used outside of the loop.
     68   //
     69   //     If ssa_form is false, then merge points do not define new
     70   //     values. Rather, the HloValueSet for the merge point contains the union
     71   //     of the merged HloValues.
     72   //
     73   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
     74   //     a new HLO value in the analysis. If false then Bitcast forwards the
     75   //     value of its operand.
     76   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
     77       const HloModule& module, bool ssa_form = false,
     78       bool bitcast_defines_value = false,
     79       const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
     80 
     81   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
     82 
     83   // Returns true if 'instruction' defines an HLO value at the given shape index
     84   // of its output.
     85   bool ValueIsDefinedAt(const HloInstruction* instruction,
     86                         const ShapeIndex& index = {}) const;
     87 
     88   // Return the HloValue defined by 'instruction' at the given shape index of
     89   // its output.
     90   //
     91   // Precondition: ValueIsDefinedAt is true for this instruction and index.
     92   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
     93                                     const ShapeIndex& index = {}) const;
     94   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
     95                               const ShapeIndex& index = {});
     96 
     97   // Return the InstructionValueSet for the given instruction.
     98   const InstructionValueSet& GetInstructionValueSet(
     99       const HloInstruction* instruction) const;
    100   InstructionValueSet& GetInstructionValueSet(
    101       const HloInstruction* instruction);
    102 
    103   // Return the HloValueSet for the given instruction at the given index or the
    104   // given position.
    105   const HloValueSet& GetValueSet(const HloInstruction* instruction,
    106                                  const ShapeIndex& index = {}) const;
    107   const HloValueSet& GetValueSet(const HloPosition& position) const;
    108   HloValueSet& GetValueSet(const HloPosition& position);
    109   HloValueSet& GetValueSet(const HloInstruction* instruction,
    110                            const ShapeIndex& index = {});
    111 
    112   // Return the unique value in the HloValueSet at the given instruction and
    113   // shape index. CHECKs if the value set does not contain a exactly one value.
    114   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
    115                                    const ShapeIndex& index = {}) const {
    116     return GetValueSet(instruction, index).GetUniqueValue();
    117   }
    118   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
    119                              const ShapeIndex& index = {}) {
    120     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
    121   }
    122 
    123   // Return the HloValue with the given Id.
    124   const HloValue& GetValue(HloValue::Id value_id) const;
    125   HloValue& GetValue(HloValue::Id value_id);
    126 
    127   // Return the total number of HloValues.
    128   int64 value_count() const { return values_.size(); }
    129 
    130   // Return a vector of all HloValues stabily sorted by HloValue::Id.
    131   const std::vector<const HloValue*>& values() const { return values_vector_; }
    132 
    133   // Return the call graph used for computing the dataflow.
    134   const CallGraph& call_graph() const { return *call_graph_; }
    135 
    136   string ToString() const;
    137 
    138   // Returns true if 'user' cannot possibly use the buffer at 'index' in
    139   // 'operand'. Returns false otherwise.
    140   //
    141   // 'operand' does not have to be an operand of 'user'. This can be the case
    142   // with indirect uses.
    143   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
    144                                const ShapeIndex& index,
    145                                const HloInstruction* user) const;
    146 
    147   // Returns true if 'user' (at 'user_index') can share a buffer with its
    148   // operand 'operand' (at 'operand_index'). Returns false otherwise.
    149   //
    150   // REQUIRES: 'operand' is an operand of 'user'.
    151   bool CanShareOperandBufferWithUser(HloInstruction* operand,
    152                                      const ShapeIndex& operand_index,
    153                                      HloInstruction* user,
    154                                      const ShapeIndex& user_index) const;
    155 
    156  protected:
    157   HloDataflowAnalysis(
    158       const HloModule& module, bool ssa_form,
    159       bool bitcast_defines_value = false,
    160       const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
    161 
    162   // Returns a new HloValue defined at the given instruction and shape index.
    163   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
    164                         bool is_phi = false);
    165 
    166   // Mark the HloValue with the given ID for deletion.
    167   void MarkValueForDeletion(HloValue::Id value_id);
    168 
    169   // Delete all HloValues marked for deletion. Should be called after
    170   // propagation is complete.
    171   void DeleteMarkedValues();
    172 
    173   // Constructs and initializes the InstructionValueSets of all instructions to
    174   // contain exactly the HloValues defined by each instruction. These values can
    175   // then propagated throughout the HLO graph by calling Propagate.
    176   Status InitializeInstructionValueSets();
    177 
    178   // Updates the value set of the given instruction based on the values flowing
    179   // into the instruction (operands and cross-computation dataflow).
    180   bool UpdateInstructionValueSet(HloInstruction* instruction);
    181 
    182   // Updates the value set for a particular instruction type. Returns whether
    183   // the instruction value set changed.
    184   bool UpdateBitcastValueSet(HloInstruction* bitcast);
    185   bool UpdateCallValueSet(HloInstruction* call);
    186   bool UpdateConditionalValueSet(HloInstruction* conditional);
    187   bool UpdateCopyValueSet(HloInstruction* copy);
    188   bool UpdateDomainValueSet(HloInstruction* domain);
    189   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
    190   bool UpdateParameterValueSet(HloInstruction* parameter);
    191   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
    192   bool UpdateTupleSelectValueSet(HloInstruction* select);
    193   bool UpdateSendValueSet(HloInstruction* send);
    194   bool UpdateTupleValueSet(HloInstruction* tuple);
    195   bool UpdateWhileValueSet(HloInstruction* xla_while);
    196   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
    197 
    198   // Propagate the dataflow through the module.
    199   void Propagate();
    200 
    201   // Return the result of the SSA Phi function applied to the given inputs at
    202   // the given instruction. If skip_top_level is true, then the top level of the
    203   // value set of 'instruction' is not modified.
    204   bool Phi(HloInstruction* instruction,
    205            absl::Span<const InstructionValueSet* const> inputs);
    206 
    207   // Updates the positions of the HloValues in the output of the given
    208   // instruction. This should be called after the instruction value set of
    209   // 'instruction' has been changed. 'prev_value_set' must point to the previous
    210   // state of the value set prior to the change. 'prev_value_set' may be null if
    211   // this is the first time positions are being computed. The previous state is
    212   // necessary to efficiently remove positions which have been eliminated due to
    213   // changes in the instructions' InstructionValueSet.
    214   void UpdatePositionsOfValuesAt(
    215       HloInstruction* instruction, const InstructionValueSet& new_value_set,
    216       const InstructionValueSet* prev_value_set = nullptr);
    217 
    218   // Verify various invariants of the dataflow analysis.
    219   Status Verify() const;
    220 
    221   const HloModule& module_;
    222   const bool ssa_form_;
    223   const bool bitcast_defines_value_;
    224 
    225   std::unique_ptr<CallGraph> call_graph_;
    226 
    227   // The map of all HloValues in the module. We pass around pointers to the
    228   // mapped HloValues, so the underlying container must keep them valid despite
    229   // mutations touching other map entries.
    230   std::unordered_map<HloValue::Id, HloValue> values_;
    231 
    232   // A map from instruction to InstructionValueSet.
    233   std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
    234 
    235   // Values marked for deletion during construction. We don't delete them
    236   // immediately because references to them may remain in ValueSets temporarily
    237   // during propagation. After construction, these values are deleted.
    238   std::vector<HloValue::Id> value_ids_to_delete_;
    239 
    240   // A vector containing all HloValues sorted by HloValue::Id.
    241   std::vector<const HloValue*> values_vector_;
    242 
    243   // The Id to use for the next HloValue.
    244   HloValue::Id next_value_id_ = 0;
    245 
    246   // Backend specific function that decides whether a fusion can share buffer
    247   // with its operand.
    248   FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr;
    249 };
    250 
    251 }  // namespace xla
    252 
    253 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
    254