1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/liveness_util.h" 17 18 #include <algorithm> 19 #include <utility> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/logical_buffer.h" 24 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/util.h" 28 29 namespace xla { 30 31 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 32 const ShapeIndex& index, 33 const HloInstruction* user, 34 const TuplePointsToAnalysis& points_to_analysis) { 35 CHECK(user->IsUserOf(operand)) 36 << "user: " << user->ToString() << " operand: " << operand->ToString(); 37 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { 38 // GetTupleElement instructions only access the top-level buffer of their 39 // operand. 40 return true; 41 } else if (user->opcode() == HloOpcode::kFusion && 42 user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 43 // Find fusion parameter associated with 'operand'. 44 auto it = std::find_if( 45 user->fused_parameters().begin(), user->fused_parameters().end(), 46 [=](HloInstruction* fused_param) { 47 return user->operand(fused_param->parameter_number()) == operand; 48 }); 49 CHECK(it != user->fused_parameters().end()); 50 // Iterate through all users of all buffer aliases of the buffer in the 51 // points-to set of fusion parameter at 'index'. 52 // Return false if any uses are detected at 'index', returns true otherwise. 53 const LogicalBuffer* buffer = 54 points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); 55 for (const BufferAlias& alias : 56 points_to_analysis.GetBufferAliases(*buffer)) { 57 for (HloInstruction* alias_user : alias.instruction()->users()) { 58 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 59 alias_user, points_to_analysis)) { 60 continue; 61 } 62 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. 63 return false; 64 } 65 } 66 // Return true: found no uses of 'operand' at 'index' in 'user'. 67 return true; 68 } 69 return false; 70 } 71 72 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 73 const ShapeIndex& index, 74 const HloInstruction* user, 75 const HloDataflowAnalysis& dataflow) { 76 CHECK(user->IsUserOf(operand)) 77 << "user: " << user->ToString() << " operand: " << operand->ToString(); 78 if (user->opcode() == HloOpcode::kFusion && 79 user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 80 // Find fusion parameter associated with 'operand'. 81 HloInstruction* fusion_param = 82 user->fused_parameter(user->operand_index(operand)); 83 // Iterate through all users of all uses of the fusion parameter value. 84 // Return false if any uses are detected, returns true otherwise. 85 const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); 86 return value.uses().empty(); 87 } else { 88 // Return false if no value at 'operand' and 'index' is used at 'user'. 89 for (const HloValue* value : 90 dataflow.GetValueSet(operand, index).values()) { 91 for (const HloUse& use : value->uses()) { 92 if (use.instruction == user) { 93 return false; 94 } 95 } 96 } 97 } 98 99 return true; 100 } 101 102 namespace { 103 104 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 105 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) 106 // where 'user' is a user of an alias of 'instruction' at 'index', and 107 // 'operand_index' is the operand index at which the alias appears in the 108 // operand list of 'user'. 109 std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( 110 HloInstruction* instruction, const ShapeIndex& index, 111 const TuplePointsToAnalysis& points_to_analysis) { 112 std::vector<std::pair<HloInstruction*, int64>> uses; 113 const PointsToSet::BufferList& points_to = 114 points_to_analysis.GetPointsToSet(instruction).element(index); 115 for (const LogicalBuffer* buffer : points_to) { 116 for (const BufferAlias& alias : 117 points_to_analysis.GetBufferAliases(*buffer)) { 118 for (HloInstruction* alias_user : alias.instruction()->users()) { 119 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 120 alias_user, points_to_analysis)) { 121 continue; 122 } 123 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { 124 uses.emplace_back(alias_user, op_idx); 125 } 126 } 127 } 128 } 129 return uses; 130 } 131 132 // Returns true if there is exactly one use of 'operand' at 'operand_index' 133 // in 'fusion.fused_instructions', where the singleton use is the fused 134 // root at operand index 'use_operand_index'. Returns false otherwise. 135 // 136 // REQUIRES: 'fusion' opcode is a kFusion instruction. 137 bool HasUniqueFusedUseOfOperandAt( 138 HloInstruction* operand, const ShapeIndex& operand_index, 139 HloInstruction* fusion, const int64 use_operand_index, 140 const TuplePointsToAnalysis& points_to_analysis) { 141 CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); 142 // Check that 'operand' is unique in the operand list of 'fusion'. 143 if (fusion->OperandIndices(operand).size() > 1) { 144 return false; 145 } 146 // Find fusion parameter associated with 'operand'. 147 const auto& fused_params = fusion->fused_parameters(); 148 auto fused_param_it = std::find_if( 149 fused_params.begin(), fused_params.end(), 150 [&](HloInstruction* fused_param) { 151 return fusion->operand(fused_param->parameter_number()) == operand; 152 }); 153 if (fused_param_it == fused_params.end()) { 154 return false; 155 } 156 auto* fused_param = *fused_param_it; 157 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. 158 auto fused_param_uses = GetAllUsesOfInstructionAtIndex( 159 fused_param, operand_index, points_to_analysis); 160 // Return true iff there is exactly one use of 'operand' at 'index', and 161 // this singleton use is the fused root (at index in 'use_operand_indices'). 162 return fused_param_uses.size() == 1 && 163 fused_param_uses[0].first == fusion->fused_expression_root() && 164 fused_param_uses[0].second == use_operand_index; 165 } 166 167 } // namespace 168 169 // User and operand can share buffers iff both instructions emit the same shape 170 // and layout, and 'user' meets one of the following qualifications: 171 // 172 // (1) Is element-wise. Or... 173 // (2) Is a loop fusion instruction where the only use of 'operand' at 'index' 174 // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root 175 // at operand 0. Or... 176 // (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion 177 // instruction where the only use of 'operand' at 'index' in the set 178 // 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... 179 // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 180 // 0. 181 // 182 // (2) and (3) can only be determined if points-to analysis is available. 183 bool CanShareOperandBufferWithUser( 184 HloInstruction* operand, const ShapeIndex& operand_index, 185 HloInstruction* user, const ShapeIndex& user_index, 186 const TuplePointsToAnalysis& points_to_analysis) { 187 CHECK(user->IsUserOf(operand)) 188 << "user: " << user->ToString() << " operand: " << operand->ToString(); 189 const Shape& operand_subshape = 190 ShapeUtil::GetSubshape(operand->shape(), operand_index); 191 const Shape& user_subshape = 192 ShapeUtil::GetSubshape(user->shape(), user_index); 193 // Check that operand and user emit the same shape and layout. 194 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 195 return false; 196 } 197 if (user->opcode() == HloOpcode::kFusion) { 198 if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && 199 user->fused_expression_root()->opcode() == 200 HloOpcode::kDynamicUpdateSlice) { 201 // Loop fusion with kDynamicUpdateSlice fused root. 202 // 203 // Returns true iff there is exactly one use of 'operand' at shape index 204 // 'operand_index', and this singleton use is the fused root at operand 205 // index 0. 206 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, 207 points_to_analysis); 208 } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && 209 user->fused_expression_root()->opcode() == HloOpcode::kAdd) { 210 // Output fusion with kAdd fused root. 211 212 // Check if one operand of kAdd fused root is either kDot, or nested 213 // kFusion of kind kTransposeDot. 214 auto* add = user->fused_expression_root(); 215 auto add_operand_it = 216 std::find_if(add->operands().begin(), add->operands().end(), 217 [&](HloInstruction* operand) { 218 return operand->opcode() == HloOpcode::kConvolution || 219 operand->opcode() == HloOpcode::kDot || 220 (operand->opcode() == HloOpcode::kFusion && 221 operand->fusion_kind() == 222 HloInstruction::FusionKind::kTransposeDot); 223 }); 224 if (add_operand_it == add->operands().end()) { 225 return false; 226 } 227 auto* matched_add_operand = *add_operand_it; 228 // Calculate operand index of 'add' operand which was not matched above. 229 const int64 other_add_operand_index = 230 matched_add_operand == add->operand(0) ? 1 : 0; 231 // Returns true iff there is exactly one use of 'operand' at shape index 232 // 'operand_index', and this singleton use is the fused root (at operand 233 // index 'other_add_operand_index'). 234 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 235 other_add_operand_index, 236 points_to_analysis); 237 } 238 } 239 if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 240 user->opcode() == HloOpcode::kWhile) { 241 // We eliminated other users in BufferLiveness::live_range_strictly_before, 242 // so here we just need to check that the use is at operand index 0. 243 std::vector<int64> operand_indices = user->OperandIndices(operand); 244 return operand_indices.size() == 1 && operand_indices[0] == 0; 245 } 246 if (user->opcode() == HloOpcode::kCall) { 247 // TODO(b/62548313): Remove when buffer assignment is module scoped and 248 // does not assign buffers to calls. 249 // Find called computation parameter associated with 'operand'. 250 const std::vector<int64> operand_indices = user->OperandIndices(operand); 251 if (operand_indices.size() > 1) { 252 return false; 253 } 254 CHECK_EQ(1, operand_indices.size()); 255 auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); 256 // Get all uses of 'operand' at 'index' in called computation. 257 auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, 258 points_to_analysis); 259 260 // Return true iff: 261 // *) There exists exactly one use of 'operand' in called computation. 262 // *) The unique use is by the root instruction of called computation. 263 // (Note: we check the root of the called computation, because the 264 // root result buffer is required to alias with the Call result buffer). 265 // *) The root instruction of the called computation is element-wise on 266 // 'operand'. 267 auto* callee_root = user->to_apply()->root_instruction(); 268 return param_uses.size() == 1 && param_uses[0].first == callee_root && 269 callee_root->IsElementwiseOnOperand(param_uses[0].second); 270 } 271 // Check if 'user' is element-wise. 272 return user->IsElementwise(); 273 } 274 275 bool CanShareOperandBufferWithUser(HloInstruction* operand, 276 const ShapeIndex& operand_index, 277 HloInstruction* user, 278 const ShapeIndex& user_index, 279 const HloDataflowAnalysis& dataflow) { 280 CHECK(user->IsUserOf(operand)) 281 << "user: " << user->ToString() << " operand: " << operand->ToString(); 282 const Shape& operand_subshape = 283 ShapeUtil::GetSubshape(operand->shape(), operand_index); 284 const Shape& user_subshape = 285 ShapeUtil::GetSubshape(user->shape(), user_index); 286 // Check that operand and user emit the same shape and layout. 287 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 288 return false; 289 } 290 291 if (user->opcode() == HloOpcode::kFusion) { 292 // Get the parameter associated with 'operand'; 293 HloInstruction* fusion_param = 294 user->fused_parameter(user->operand_index(operand)); 295 296 const HloValue& value = 297 dataflow.GetValueDefinedAt(fusion_param, operand_index); 298 if (value.uses().size() != 1) { 299 return false; 300 } 301 const HloUse& use = value.uses()[0]; 302 303 if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && 304 user->fused_expression_root()->opcode() == 305 HloOpcode::kDynamicUpdateSlice) { 306 // Loop fusion with kDynamicUpdateSlice fused root. 307 // 308 // Returns true iff there is exactly one use of 'operand' at shape index 309 // 'operand_index', and this singleton use is the fused root at operand 310 // index 0. 311 return use.instruction == user->fused_expression_root() && 312 use.operand_number == 0; 313 } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && 314 user->fused_expression_root()->opcode() == HloOpcode::kAdd) { 315 // Output fusion with kAdd fused root. 316 317 // Check if one operand of kAdd fused root is either kDot, or nested 318 // kFusion of kind kTransposeDot. 319 auto* add = user->fused_expression_root(); 320 auto add_operand_it = 321 std::find_if(add->operands().begin(), add->operands().end(), 322 [&](HloInstruction* operand) { 323 return operand->opcode() == HloOpcode::kConvolution || 324 operand->opcode() == HloOpcode::kDot || 325 (operand->opcode() == HloOpcode::kFusion && 326 operand->fusion_kind() == 327 HloInstruction::FusionKind::kTransposeDot); 328 }); 329 if (add_operand_it == add->operands().end()) { 330 return false; 331 } 332 auto* matched_add_operand = *add_operand_it; 333 // Calculate operand index of 'add' operand which was not matched above. 334 const int64 other_add_operand_index = 335 matched_add_operand == add->operand(0) ? 1 : 0; 336 // Returns true iff there is exactly one use of 'operand' at shape index 337 // 'operand_index', and this singleton use is the fused root (at operand 338 // index 'other_add_operand_index'). 339 return use.instruction == user->fused_expression_root() && 340 use.operand_number == other_add_operand_index; 341 } 342 } 343 if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 344 user->opcode() == HloOpcode::kWhile) { 345 // We eliminated other users in BufferLiveness::live_range_strictly_before, 346 // so here we just need to check that the use is at operand index 0. 347 std::vector<int64> operand_indices = user->OperandIndices(operand); 348 return operand_indices.size() == 1 && operand_indices[0] == 0; 349 } 350 if (user->opcode() == HloOpcode::kCall) { 351 // Get all uses of value defined by 'operand' at 'operand_index'. 352 const auto& uses = 353 dataflow.GetValueDefinedAt(operand, operand_index).uses(); 354 // Return true iff: 355 // *) There exists two uses of 'operand'. 356 // *) One use is by 'user' (caller). 357 // *) One use is by root instruction of called computation (callee root). 358 // (Note: we check the root of the called computation, because the 359 // root result buffer is required to alias with the Call result buffer). 360 // *) The root instruction of the called computation is element-wise on 361 // 'operand'. 362 const bool found_caller_use = 363 std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { 364 return use.instruction == user; 365 }) != uses.end(); 366 auto* callee_root = user->to_apply()->root_instruction(); 367 const bool found_elementwise_callee_use = 368 std::find_if( 369 uses.begin(), uses.end(), [callee_root](const HloUse& use) { 370 return use.instruction == callee_root && 371 callee_root->IsElementwiseOnOperand(use.operand_number); 372 }) != uses.end(); 373 return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; 374 } 375 // Check if 'user' is element-wise. 376 return user->IsElementwise(); 377 } 378 379 } // namespace xla 380