Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Analysis for determining the possible set of values for all positions
     17 // (instructions and ShapeIndexes) in the HLO module. The analysis is
     18 // module-scoped, tracking values across computation boundaries.
     19 
     20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
     21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
     22 
     23 #include <memory>
     24 #include <string>
     25 #include <unordered_map>
     26 #include <vector>
     27 
     28 #include "tensorflow/compiler/xla/service/call_graph.h"
     29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     30 #include "tensorflow/compiler/xla/service/hlo_module.h"
     31 #include "tensorflow/compiler/xla/service/hlo_value.h"
     32 #include "tensorflow/compiler/xla/shape_util.h"
     33 #include "tensorflow/compiler/xla/status.h"
     34 #include "tensorflow/compiler/xla/statusor.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/compiler/xla/xla_data.pb.h"
     37 #include "tensorflow/core/lib/gtl/array_slice.h"
     38 #include "tensorflow/core/platform/macros.h"
     39 
     40 namespace xla {
     41 
     42 // Analysis which identifies all HLO values and their uses in an HLO module.
     43 class HloDataflowAnalysis {
     44  public:
     45   // Run dataflow analysis on the given module. Parameters:
     46   //
     47   //   ssa_form : If true then new values are defined at the merge points of
     48   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
     49   //     values".  The merge is formed by the init value and loop backedge. The
     50   //     SSA form is minimal in that a new phi value is defined only if the
     51   //     merge point is reachable by multiple different values. The SSA form is
     52   //     also in loop-closed form in that no value defined inside of a loop
     53   //     (while body) is used outside of the loop.
     54   //
     55   //     If ssa_form is false, then merge points do not define new
     56   //     values. Rather, the HloValueSet for the merge point contains the union
     57   //     of the merged HloValues.
     58   //
     59   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
     60   //     a new HLO value in the analysis. If false then Bitcast forwards the
     61   //     value of its operand.
     62   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
     63       const HloModule& module, bool ssa_form = false,
     64       bool bitcast_defines_value = false);
     65 
     66   // Returns true if 'instruction' defines an HLO value at the given shape index
     67   // of its output.
     68   bool ValueIsDefinedAt(const HloInstruction* instruction,
     69                         const ShapeIndex& index = {}) const;
     70 
     71   // Return the HloValue defined by 'instruction' at the given shape index of
     72   // its output.
     73   //
     74   // Precondition: ValueIsDefinedAt is true for this instruction and index.
     75   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
     76                                     const ShapeIndex& index = {}) const;
     77   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
     78                               const ShapeIndex& index = {});
     79 
     80   // Return the InstructionValueSet for the given instruction.
     81   const InstructionValueSet& GetInstructionValueSet(
     82       const HloInstruction* instruction) const;
     83   InstructionValueSet& GetInstructionValueSet(
     84       const HloInstruction* instruction);
     85 
     86   // Return the HloValueSet for the given instruction at the given index or the
     87   // given position.
     88   const HloValueSet& GetValueSet(const HloInstruction* instruction,
     89                                  const ShapeIndex& index = {}) const;
     90   const HloValueSet& GetValueSet(const HloPosition& position) const;
     91   HloValueSet& GetValueSet(const HloPosition& position);
     92   HloValueSet& GetValueSet(const HloInstruction* instruction,
     93                            const ShapeIndex& index = {});
     94 
     95   // Return the unique value in the HloValueSet at the given instruction and
     96   // shape index. CHECKs if the value set does not contain exactly one value.
     97   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
     98                                    const ShapeIndex& index = {}) const {
     99     return GetValueSet(instruction, index).GetUniqueValue();
    100   }
    101   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
    102                              const ShapeIndex& index = {}) {
    103     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
    104   }
    105 
    106   // Return the HloValue with the given Id.
    107   const HloValue& GetValue(HloValue::Id value_id) const;
    108   HloValue& GetValue(HloValue::Id value_id);
    109 
    110   // Return the total number of HloValues.
    111   int64 value_count() const { return values_.size(); }
    112 
    113   // Return a vector of all HloValues stably sorted by HloValue::Id.
    114   const std::vector<const HloValue*>& values() const { return values_vector_; }
    115 
    116   // Return the call graph used for computing the dataflow.
    117   const CallGraph& call_graph() const { return *call_graph_; }
    118 
    119   string ToString() const;
    120 
    121  protected:
    122   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
    123                       bool bitcast_defines_value = false);
    124 
    125   // Returns a new HloValue defined at the given instruction and shape index.
    126   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
    127                         bool is_phi = false);
    128 
    129   // Mark the HloValue with the given ID for deletion.
    130   void MarkValueForDeletion(HloValue::Id value_id);
    131 
    132   // Delete all HloValues marked for deletion. Should be called after
    133   // propagation is complete.
    134   void DeleteMarkedValues();
    135 
    136   // Constructs and initializes the InstructionValueSets of all instructions to
    137   // contain exactly the HloValues defined by each instruction. These values can
    138   // then be propagated throughout the HLO graph by calling Propagate.
    139   Status InitializeInstructionValueSets();
    140 
    141   // Updates the value set of the given instruction based on the values flowing
    142   // into the instruction (operands and cross-computation dataflow).
    143   bool UpdateInstructionValueSet(HloInstruction* instruction);
    144 
    145   // Updates the value set for a particular instruction type. Returns whether
    146   // the instruction value set changed.
    147   bool UpdateBitcastValueSet(HloInstruction* bitcast);
    148   bool UpdateSliceValueSet(HloInstruction* slice);
    149   bool UpdateCallValueSet(HloInstruction* call);
    150   bool UpdateConditionalValueSet(HloInstruction* conditional);
    151   bool UpdateCopyValueSet(HloInstruction* copy);
    152   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
    153   bool UpdateParameterValueSet(HloInstruction* parameter);
    154   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
    155   bool UpdateSelectValueSet(HloInstruction* select);
    156   bool UpdateSendValueSet(HloInstruction* send);
    157   bool UpdateTupleValueSet(HloInstruction* tuple);
    158   bool UpdateWhileValueSet(HloInstruction* xla_while);
    159 
    160   // Propagate the dataflow through the module.
    161   void Propagate();
    162 
    163   // Return the result of the SSA Phi function applied to the given inputs at
    164   // the given instruction. NOTE(review): this comment previously described a
    165   // 'skip_top_level' flag, but Phi as declared here takes no such parameter.
    166   bool Phi(HloInstruction* instruction,
    167            tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
    168 
    169   // Updates the positions of the HloValues in the output of the given
    170   // instruction. This should be called after the instruction value set of
    171   // 'instruction' has been changed. 'prev_value_set' must point to the previous
    172   // state of the value set prior to the change. 'prev_value_set' may be null if
    173   // this is the first time positions are being computed. The previous state is
    174   // necessary to efficiently remove positions which have been eliminated due to
    175   // changes in the instructions' InstructionValueSet.
    176   void UpdatePositionsOfValuesAt(
    177       HloInstruction* instruction, const InstructionValueSet& new_value_set,
    178       const InstructionValueSet* prev_value_set = nullptr);
    179 
    180   // Verify various invariants of the dataflow analysis.
    181   Status Verify() const;
    182 
    183   const HloModule& module_;
    184   const bool ssa_form_;
    185   const bool bitcast_defines_value_;
    186 
    187   std::unique_ptr<CallGraph> call_graph_;
    188 
    189   // The map of all HloValues in the module. We pass around pointers to the
    190   // mapped HloValues, so the underlying container must keep them valid despite
    191   // mutations touching other map entries.
    192   std::unordered_map<HloValue::Id, HloValue> values_;
    193 
    194   // A map from instruction to InstructionValueSet.
    195   std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
    196 
    197   // Values marked for deletion during construction. We don't delete them
    198   // immediately because references to them may remain in ValueSets temporarily
    199   // during propagation. After construction, these values are deleted.
    200   std::vector<HloValue::Id> value_ids_to_delete_;
    201 
    202   // A vector containing all HloValues sorted by HloValue::Id.
    203   std::vector<const HloValue*> values_vector_;
    204 
    205   // The Id to use for the next HloValue.
    206   HloValue::Id next_value_id_ = 0;
    207 };
    208 
    209 }  // namespace xla
    210 
    211 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
    212