1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_value.h" 17 18 #include <algorithm> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/map_util.h" 22 #include "tensorflow/compiler/xla/ptr_util.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/status.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/util.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/strings/str_util.h" 33 #include "tensorflow/core/lib/strings/strcat.h" 34 #include "tensorflow/core/platform/logging.h" 35 36 namespace xla { 37 38 using ::tensorflow::str_util::Join; 39 using ::tensorflow::strings::StrAppend; 40 using ::tensorflow::strings::StrCat; 41 42 const Shape& HloPosition::shape() const { 43 return ShapeUtil::GetSubshape(instruction->shape(), index); 44 } 45 46 string HloPosition::ToString() const { 47 string index_str = 48 ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; 49 return StrCat(instruction->name(), index_str); 50 } 51 52 std::ostream& operator<<(std::ostream& out, const HloPosition& position) { 53 out << position.ToString(); 54 return out; 55 } 56 57 string HloUse::ToString() const { 58 string index_str = 59 ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) 60 ? (" " + operand_index.ToString()) 61 : ""; 62 return StrCat(instruction->name(), ", operand ", operand_number, index_str); 63 } 64 65 std::ostream& operator<<(std::ostream& out, const HloUse& use) { 66 out << use.ToString(); 67 return out; 68 } 69 70 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, 71 const ShapeIndex& index, bool is_phi) 72 : id_(id), is_phi_(is_phi) { 73 // The defining position is always the first element in the positions_ vector. 74 positions_.push_back(HloPosition{instruction, index}); 75 } 76 77 bool HloValue::operator==(const HloValue& other) const { 78 bool equal = defining_instruction() == other.defining_instruction() && 79 defining_index() == other.defining_index(); 80 // If the values are equal they most both be phi (or non phi). 81 CHECK(!(equal && is_phi() != other.is_phi())); 82 return equal; 83 } 84 85 bool HloValue::operator!=(const HloValue& other) const { 86 return !(*this == other); 87 } 88 89 string HloValue::ToShortString() const { 90 string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) 91 ? defining_index().ToString() 92 : ""; 93 return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), 94 index_str); 95 } 96 97 string HloValue::ToString(int indent) const { 98 string indentation(indent, ' '); 99 string out = StrCat(indentation, ToShortString(), ", positions:\n"); 100 for (const HloPosition& position : positions()) { 101 StrAppend(&out, indentation, " ", position.ToString(), "\n"); 102 } 103 StrAppend(&out, indentation, " uses:\n"); 104 for (const HloUse& use : uses()) { 105 StrAppend(&out, indentation, " ", use.ToString(), "\n"); 106 } 107 return out; 108 } 109 110 namespace { 111 112 // Returns true if the instruction 'user' may use the value at the given 113 // ShapeIndex in the given operand. Generally, instruction which pass through 114 // values transparently without reading the value are not considered to use the 115 // value. 116 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, 117 const HloInstruction* user) { 118 switch (user->opcode()) { 119 case HloOpcode::kGetTupleElement: 120 case HloOpcode::kCopy: 121 // These instructions only access the top-level values of their 122 // operand. Non-top-level (nested) values are passed through 123 // transparently. 124 CHECK_EQ(operand_number, 0); 125 return index.empty(); 126 case HloOpcode::kSelect: 127 // Select does not use any nested elements of its selected-from operands 128 // (operand 1 and 2) 129 CHECK_GE(operand_number, 0); 130 CHECK_LE(operand_number, 2); 131 return operand_number == 0 || index.empty(); 132 133 case HloOpcode::kTuple: 134 // These instructions always pass through their operands transparently. 135 return false; 136 137 case HloOpcode::kCall: 138 case HloOpcode::kWhile: 139 // Although call and while instructions pass through their operands, they 140 // are considered uses. 141 return true; 142 143 default: 144 return true; 145 } 146 } 147 148 } // namespace 149 150 void HloValue::SetPositionsAndComputeUses( 151 tensorflow::gtl::ArraySlice<HloPosition> positions) { 152 CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; 153 154 // The positions must be unique and should not contain the defining position 155 // as this is added at construction time. 156 for (const HloPosition& position_a : positions) { 157 DCHECK_NE(position_a, defining_position()); 158 for (const HloPosition& position_b : positions) { 159 if (&position_a != &position_b) { 160 DCHECK_NE(position_a, position_b); 161 } 162 } 163 } 164 165 positions_.insert(positions_.end(), positions.begin(), positions.end()); 166 167 // Gather the computation roots at which this value appears. 168 tensorflow::gtl::FlatSet<HloInstruction*> root_positions; 169 for (const HloPosition& position : positions_) { 170 if (position.instruction == 171 position.instruction->parent()->root_instruction()) { 172 root_positions.insert(position.instruction); 173 } 174 } 175 176 // Build vector of HloUses for the value. 177 for (const HloPosition& position : positions_) { 178 for (HloInstruction* user : position.instruction->users()) { 179 for (int64 operand_number : user->OperandIndices(position.instruction)) { 180 // Root instructions of computations are considered to be uses whether 181 // or not the root instruction itself actually uses the value. 182 if (MayUseOperandValue(operand_number, position.index, user) || 183 ContainsKey(root_positions, user)) { 184 HloUse new_use{user, operand_number, position.index}; 185 186 // The new use must not already exist in uses_. 187 for (const HloUse& use : uses_) { 188 DCHECK_NE(use, new_use); 189 } 190 191 uses_.push_back(std::move(new_use)); 192 } 193 } 194 } 195 196 // Update liveout status of this HloValue. 197 const HloModule& module = *position.instruction->parent()->parent(); 198 if (position.instruction == 199 module.entry_computation()->root_instruction()) { 200 live_out_of_module_ = true; 201 } 202 } 203 } 204 205 std::ostream& operator<<(std::ostream& out, const HloValue& value) { 206 out << value.ToShortString(); 207 return out; 208 } 209 210 void HloValueSet::SortAndUniquifyValues() { 211 std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); 212 values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), 213 values_.end()); 214 } 215 216 string HloValueSet::ToString() const { 217 return StrCat("HloValueSet: ", 218 Join(values_, ", ", [](string* result, const HloValue* value) { 219 result->append(value->ToShortString()); 220 })); 221 } 222 223 bool HloValueSet::AssignUnionOf( 224 tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) { 225 HloValueSet union_set; 226 for (const HloValueSet* input : inputs) { 227 for (const HloValue* value : input->values()) { 228 union_set.values_.push_back(value); 229 } 230 } 231 union_set.SortAndUniquifyValues(); 232 if (*this != union_set) { 233 *this = union_set; 234 return true; 235 } 236 return false; 237 } 238 239 bool HloValueSet::AddValue(const HloValue* value) { 240 auto it = std::lower_bound(values_.begin(), values_.end(), value, 241 HloValue::IdLessThan); 242 if (it == values_.end() || (*it)->id() != value->id()) { 243 values_.insert(it, value); 244 return true; 245 } 246 return false; // already exists 247 } 248 249 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { 250 out << value_set.ToString(); 251 return out; 252 } 253 254 bool InstructionValueSet::AssignUnionOf( 255 tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) { 256 CHECK_GT(inputs.size(), 0); 257 for (int i = 1; i < inputs.size(); ++i) { 258 DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); 259 } 260 bool changed = false; 261 for (auto& pair : *this) { 262 const ShapeIndex& index = pair.first; 263 HloValueSet& value_set = pair.second; 264 265 std::vector<const HloValueSet*> input_value_sets; 266 for (const InstructionValueSet* input : inputs) { 267 input_value_sets.push_back(&input->element(index)); 268 } 269 changed |= value_set.AssignUnionOf(input_value_sets); 270 } 271 272 return changed; 273 } 274 275 std::ostream& operator<<(std::ostream& out, 276 const InstructionValueSet& instruction_value_set) { 277 out << instruction_value_set.ToString(); 278 return out; 279 } 280 281 string InstructionValueSet::ToString() const { 282 string out = 283 StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); 284 ForEachElement([this, &out](const ShapeIndex& index, 285 const HloValueSet& value_set) { 286 StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); 287 }); 288 return out; 289 } 290 291 } // namespace xla 292