1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Analysis for determining the possible set of values for all positions 17 // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped 18 // tracking values across computation boundaries. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 22 23 #include <memory> 24 #include <string> 25 #include <unordered_map> 26 #include <vector> 27 28 #include "tensorflow/compiler/xla/service/call_graph.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/service/hlo_module.h" 31 #include "tensorflow/compiler/xla/service/hlo_value.h" 32 #include "tensorflow/compiler/xla/shape_util.h" 33 #include "tensorflow/compiler/xla/status.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/gtl/array_slice.h" 38 #include "tensorflow/core/platform/macros.h" 39 40 namespace xla { 41 42 // Analysis which identifies all HLO values and their uses in an HLO module. 43 class HloDataflowAnalysis { 44 public: 45 // Run dataflow analysis on the given module. Parameters: 46 // 47 // ssa_form : If true then new values are defined at the merge points of 48 // kWhile instructions. Abusing nomenclature somewhat, we call these "phi 49 // values". The merge is formed by the init value and loop backedge. The 50 // SSA form is minimal in that a new phi value is defined only if the 51 // merge point is reachable by multiple different values. The SSA form is 52 // also in loop-closed form in that no values defined inside of a loop 53 // (while body) is used outside of the loop. 54 // 55 // If ssa_form is false, then merge points do not define new 56 // values. Rather, the HloValueSet for the merge point contains the union 57 // of the merged HloValues. 58 // 59 // bitcast_defines_value : If true then the Bitcast HLO instruction defines 60 // a new HLO value in the analysis. If false then Bitcast forwards the 61 // value of its operand. 62 static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run( 63 const HloModule& module, bool ssa_form = false, 64 bool bitcast_defines_value = false); 65 66 // Returns true if 'instruction' defines an HLO value at the given shape index 67 // of its output. 68 bool ValueIsDefinedAt(const HloInstruction* instruction, 69 const ShapeIndex& index = {}) const; 70 71 // Return the HloValue defined by 'instruction' at the given shape index of 72 // its output. 73 // 74 // Precondition: ValueIsDefinedAt is true for this instruction and index. 75 const HloValue& GetValueDefinedAt(const HloInstruction* instruction, 76 const ShapeIndex& index = {}) const; 77 HloValue& GetValueDefinedAt(const HloInstruction* instruction, 78 const ShapeIndex& index = {}); 79 80 // Return the InstructionValueSet for the given instruction. 81 const InstructionValueSet& GetInstructionValueSet( 82 const HloInstruction* instruction) const; 83 InstructionValueSet& GetInstructionValueSet( 84 const HloInstruction* instruction); 85 86 // Return the HloValueSet for the given instruction at the given index or the 87 // given position. 88 const HloValueSet& GetValueSet(const HloInstruction* instruction, 89 const ShapeIndex& index = {}) const; 90 const HloValueSet& GetValueSet(const HloPosition& position) const; 91 HloValueSet& GetValueSet(const HloPosition& position); 92 HloValueSet& GetValueSet(const HloInstruction* instruction, 93 const ShapeIndex& index = {}); 94 95 // Return the unique value in the HloValueSet at the given instruction and 96 // shape index. CHECKs if the value set does not contain a exactly one value. 97 const HloValue& GetUniqueValueAt(const HloInstruction* instruction, 98 const ShapeIndex& index = {}) const { 99 return GetValueSet(instruction, index).GetUniqueValue(); 100 } 101 HloValue& GetUniqueValueAt(const HloInstruction* instruction, 102 const ShapeIndex& index = {}) { 103 return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); 104 } 105 106 // Return the HloValue with the given Id. 107 const HloValue& GetValue(HloValue::Id value_id) const; 108 HloValue& GetValue(HloValue::Id value_id); 109 110 // Return the total number of HloValues. 111 int64 value_count() const { return values_.size(); } 112 113 // Return a vector of all HloValues stabily sorted by HloValue::Id. 114 const std::vector<const HloValue*>& values() const { return values_vector_; } 115 116 // Return the call graph used for computing the dataflow. 117 const CallGraph& call_graph() const { return *call_graph_; } 118 119 string ToString() const; 120 121 protected: 122 HloDataflowAnalysis(const HloModule& module, bool ssa_form, 123 bool bitcast_defines_value = false); 124 125 // Returns a new HloValue defined at the given instruction and shape index. 126 HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, 127 bool is_phi = false); 128 129 // Mark the HloValue with the given ID for deletion. 130 void MarkValueForDeletion(HloValue::Id value_id); 131 132 // Delete all HloValues marked for deletion. Should be called after 133 // propagation is complete. 134 void DeleteMarkedValues(); 135 136 // Constructs and initializes the InstructionValueSets of all instructions to 137 // contain exactly the HloValues defined by each instruction. These values can 138 // then propagated throughout the HLO graph by calling Propagate. 139 Status InitializeInstructionValueSets(); 140 141 // Updates the value set of the given instruction based on the values flowing 142 // into the instruction (operands and cross-computation dataflow). 143 bool UpdateInstructionValueSet(HloInstruction* instruction); 144 145 // Updates the value set for a particular instruction type. Returns whether 146 // the instruction value set changed. 147 bool UpdateBitcastValueSet(HloInstruction* bitcast); 148 bool UpdateSliceValueSet(HloInstruction* slice); 149 bool UpdateCallValueSet(HloInstruction* call); 150 bool UpdateConditionalValueSet(HloInstruction* conditional); 151 bool UpdateCopyValueSet(HloInstruction* copy); 152 bool UpdateGetTupleElementValueSet(HloInstruction* gte); 153 bool UpdateParameterValueSet(HloInstruction* parameter); 154 bool UpdateRecvDoneValueSet(HloInstruction* recv_done); 155 bool UpdateSelectValueSet(HloInstruction* select); 156 bool UpdateSendValueSet(HloInstruction* send); 157 bool UpdateTupleValueSet(HloInstruction* tuple); 158 bool UpdateWhileValueSet(HloInstruction* xla_while); 159 160 // Propagate the dataflow through the module. 161 void Propagate(); 162 163 // Return the result of the SSA Phi function applied to the given inputs at 164 // the given instruction. If skip_top_level is true, then the top level of the 165 // value set of 'instruction' is not modified. 166 bool Phi(HloInstruction* instruction, 167 tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs); 168 169 // Updates the positions of the HloValues in the output of the given 170 // instruction. This should be called after the instruction value set of 171 // 'instruction' has been changed. 'prev_value_set' must point to the previous 172 // state of the value set prior to the change. 'prev_value_set' may be null if 173 // this is the first time positions are being computed. The previous state is 174 // necessary to efficiently remove positions which have been eliminated due to 175 // changes in the instructions' InstructionValueSet. 176 void UpdatePositionsOfValuesAt( 177 HloInstruction* instruction, const InstructionValueSet& new_value_set, 178 const InstructionValueSet* prev_value_set = nullptr); 179 180 // Verify various invariants of the dataflow analysis. 181 Status Verify() const; 182 183 const HloModule& module_; 184 const bool ssa_form_; 185 const bool bitcast_defines_value_; 186 187 std::unique_ptr<CallGraph> call_graph_; 188 189 // The map of all HloValues in the module. We pass around pointers to the 190 // mapped HloValues, so the underlying container must keep them valid despite 191 // mutations touching other map entries. 192 std::unordered_map<HloValue::Id, HloValue> values_; 193 194 // A map from instruction to InstructionValueSet. 195 std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_; 196 197 // Values marked for deletion during construction. We don't delete them 198 // immediately because references to them may remain in ValueSets temporarily 199 // during propagation. After construction, these values are deleted. 200 std::vector<HloValue::Id> value_ids_to_delete_; 201 202 // A vector containing all HloValues sorted by HloValue::Id. 203 std::vector<const HloValue*> values_vector_; 204 205 // The Id to use for the next HloValue. 206 HloValue::Id next_value_id_ = 0; 207 }; 208 209 } // namespace xla 210 211 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 212