1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_value.h" 17 18 #include <algorithm> 19 #include <utility> 20 21 #include "absl/container/flat_hash_set.h" 22 #include "absl/memory/memory.h" 23 #include "absl/strings/str_cat.h" 24 #include "absl/strings/str_join.h" 25 #include "tensorflow/compiler/xla/map_util.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/status.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/core/lib/core/errors.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/types.h" 37 38 namespace xla { 39 40 using absl::StrAppend; 41 using absl::StrCat; 42 43 const Shape& HloPosition::shape() const { 44 return ShapeUtil::GetSubshape(instruction->shape(), index); 45 } 46 47 string HloPosition::ToString() const { 48 string index_str = 49 instruction->shape().IsTuple() ? (" " + index.ToString()) : ""; 50 return StrCat(instruction->name(), index_str); 51 } 52 53 std::ostream& operator<<(std::ostream& out, const HloPosition& position) { 54 out << position.ToString(); 55 return out; 56 } 57 58 string HloUse::ToString() const { 59 string index_str = instruction->operand(operand_number)->shape().IsTuple() 60 ? (" " + operand_index.ToString()) 61 : ""; 62 return StrCat(instruction->name(), ", operand ", operand_number, index_str); 63 } 64 65 std::ostream& operator<<(std::ostream& out, const HloUse& use) { 66 out << use.ToString(); 67 return out; 68 } 69 70 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, 71 const ShapeIndex& index, bool is_phi) 72 : BufferValue(instruction, index, id), is_phi_(is_phi) { 73 // The defining position is always the first element in the positions_ vector. 74 positions_.push_back(HloPosition{instruction, index}); 75 } 76 77 bool HloValue::operator==(const HloValue& other) const { 78 bool equal = defining_instruction() == other.defining_instruction() && 79 defining_index() == other.defining_index(); 80 // If the values are equal they most both be phi (or non phi). 81 CHECK(!(equal && is_phi() != other.is_phi())); 82 return equal; 83 } 84 85 bool HloValue::operator!=(const HloValue& other) const { 86 return !(*this == other); 87 } 88 89 string HloValue::ToShortString() const { 90 string index_str = defining_instruction()->shape().IsTuple() 91 ? defining_index().ToString() 92 : ""; 93 return StrCat(id(), " ", is_phi_ ? "PHI " : "", 94 defining_instruction()->name(), index_str); 95 } 96 97 string HloValue::ToString(int indent) const { 98 string indentation(indent, ' '); 99 string out = StrCat(indentation, ToShortString(), ", positions:\n"); 100 for (const HloPosition& position : positions()) { 101 StrAppend(&out, indentation, " ", position.ToString(), "\n"); 102 } 103 StrAppend(&out, indentation, " uses:\n"); 104 for (const HloUse& use : uses()) { 105 StrAppend(&out, indentation, " ", use.ToString(), "\n"); 106 } 107 return out; 108 } 109 110 namespace { 111 112 // Returns true if the instruction 'user' may use the value at the given 113 // ShapeIndex in the given operand. Generally, instruction which pass through 114 // values transparently without reading the value are not considered to use the 115 // value. 116 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, 117 const HloInstruction* user) { 118 switch (user->opcode()) { 119 case HloOpcode::kGetTupleElement: 120 case HloOpcode::kCopy: 121 // These instructions only access the top-level values of their 122 // operand. Non-top-level (nested) values are passed through 123 // transparently. 124 CHECK_EQ(operand_number, 0); 125 return index.empty(); 126 case HloOpcode::kTupleSelect: 127 // Select does not use any nested elements of its selected-from operands 128 // (operand 1 and 2) 129 CHECK_GE(operand_number, 0); 130 CHECK_LE(operand_number, 2); 131 return operand_number == 0 || index.empty(); 132 133 case HloOpcode::kDomain: 134 case HloOpcode::kTuple: 135 // These instructions always pass through their operands transparently. 136 return false; 137 138 case HloOpcode::kCall: 139 case HloOpcode::kWhile: 140 // Although call and while instructions pass through their operands, they 141 // are considered uses. 142 return true; 143 144 default: 145 return true; 146 } 147 } 148 149 } // namespace 150 151 void HloValue::SetPositionsAndComputeUses( 152 absl::Span<const HloPosition> positions) { 153 CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; 154 155 // The positions must be unique and should not contain the defining position 156 // as this is added at construction time. 157 for (const HloPosition& position_a : positions) { 158 DCHECK_NE(position_a, defining_position()); 159 for (const HloPosition& position_b : positions) { 160 if (&position_a != &position_b) { 161 DCHECK_NE(position_a, position_b); 162 } 163 } 164 } 165 166 positions_.insert(positions_.end(), positions.begin(), positions.end()); 167 168 // Gather the computation roots at which this value appears. 169 absl::flat_hash_set<HloInstruction*> root_positions; 170 for (const HloPosition& position : positions_) { 171 if (position.instruction == 172 position.instruction->parent()->root_instruction()) { 173 root_positions.insert(position.instruction); 174 } 175 } 176 177 // Build vector of HloUses for the value. 178 for (const HloPosition& position : positions_) { 179 for (HloInstruction* user : position.instruction->users()) { 180 for (int64 operand_number : user->OperandIndices(position.instruction)) { 181 // Root instructions of computations are considered to be uses whether 182 // or not the root instruction itself actually uses the value. 183 if (MayUseOperandValue(operand_number, position.index, user) || 184 ContainsKey(root_positions, user)) { 185 HloUse new_use{user, operand_number, position.index}; 186 187 // The new use must not already exist in uses_. 188 for (const HloUse& use : uses_) { 189 DCHECK_NE(use, new_use); 190 } 191 192 uses_.push_back(std::move(new_use)); 193 } 194 } 195 } 196 197 // Update liveout status of this HloValue. 198 const HloModule& module = *position.instruction->parent()->parent(); 199 if (position.instruction == 200 module.entry_computation()->root_instruction()) { 201 live_out_of_module_ = true; 202 } 203 } 204 } 205 206 std::ostream& operator<<(std::ostream& out, const HloValue& value) { 207 out << value.ToShortString(); 208 return out; 209 } 210 211 void HloValueSet::SortAndUniquifyValues() { 212 absl::c_sort(values_, HloValue::IdLessThan); 213 values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), 214 values_.end()); 215 } 216 217 string HloValueSet::ToString() const { 218 return StrCat( 219 "HloValueSet: ", 220 absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { 221 result->append(value->ToShortString()); 222 })); 223 } 224 225 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) { 226 HloValueSet union_set; 227 for (const HloValueSet* input : inputs) { 228 for (const HloValue* value : input->values()) { 229 union_set.values_.push_back(value); 230 } 231 } 232 union_set.SortAndUniquifyValues(); 233 if (*this != union_set) { 234 *this = union_set; 235 return true; 236 } 237 return false; 238 } 239 240 bool HloValueSet::AddValue(const HloValue* value) { 241 auto it = std::lower_bound(values_.begin(), values_.end(), value, 242 HloValue::IdLessThan); 243 if (it == values_.end() || (*it)->id() != value->id()) { 244 values_.insert(it, value); 245 return true; 246 } 247 return false; // already exists 248 } 249 250 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { 251 out << value_set.ToString(); 252 return out; 253 } 254 255 bool InstructionValueSet::AssignUnionOf( 256 absl::Span<const InstructionValueSet* const> inputs) { 257 CHECK_GT(inputs.size(), 0); 258 for (int i = 1; i < inputs.size(); ++i) { 259 DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); 260 } 261 bool changed = false; 262 for (auto& pair : *this) { 263 const ShapeIndex& index = pair.first; 264 HloValueSet& value_set = pair.second; 265 266 std::vector<const HloValueSet*> input_value_sets; 267 for (const InstructionValueSet* input : inputs) { 268 input_value_sets.push_back(&input->element(index)); 269 } 270 changed |= value_set.AssignUnionOf(input_value_sets); 271 } 272 273 return changed; 274 } 275 276 std::ostream& operator<<(std::ostream& out, 277 const InstructionValueSet& instruction_value_set) { 278 out << instruction_value_set.ToString(); 279 return out; 280 } 281 282 string InstructionValueSet::ToString() const { 283 string out = 284 StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); 285 ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) { 286 StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); 287 }); 288 return out; 289 } 290 291 } // namespace xla 292