1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 17 18 #include <ostream> 19 #include <utility> 20 #include <vector> 21 22 #include "absl/container/flat_hash_set.h" 23 #include "absl/memory/memory.h" 24 #include "absl/strings/str_cat.h" 25 #include "absl/strings/str_format.h" 26 #include "absl/strings/str_join.h" 27 #include "tensorflow/compiler/xla/map_util.h" 28 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/compiler/xla/util.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/platform/logging.h" 35 36 namespace xla { 37 38 string BufferAlias::ToString() const { 39 return absl::StrCat("BufferAlias(", instruction_->name(), "[", 40 absl::StrJoin(index_, ","), "])"); 41 } 42 43 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { 44 out << buffer_alias.ToString(); 45 return out; 46 } 47 48 bool PointsToSet::IsAmbiguous() const { 49 bool ambiguous = false; 50 ForEachElement( 51 [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) { 52 ambiguous |= points_to.size() > 1; 53 }); 54 return ambiguous; 55 } 56 57 bool PointsToSet::IsDistinct() const { 58 bool distinct = true; 59 absl::flat_hash_set<const LogicalBuffer*> all_points_to; 60 ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) { 61 for (auto& buffer : points_to) { 62 if (all_points_to.contains(buffer)) { 63 distinct = false; 64 } 65 all_points_to.insert(buffer); 66 } 67 }); 68 return distinct; 69 } 70 71 size_t PointsToSet::size() const { 72 // Because pointed-to elements may be duplicated we have to create a flattened 73 // set and return the size. 74 return CreateFlattenedSet().size(); 75 } 76 77 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const { 78 BufferSet flat_set; 79 ForEachElement( 80 [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) { 81 flat_set.insert(buffers.begin(), buffers.end()); 82 }); 83 return flat_set; 84 } 85 86 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { 87 bool found = false; 88 ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, 89 const BufferList& pointed_to_buffers) { 90 if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) { 91 found = true; 92 } 93 }); 94 return found; 95 } 96 97 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, 98 const ShapeIndex& index) const { 99 const auto& pointed_to_buffers = element(index); 100 return absl::c_linear_search(pointed_to_buffers, &buffer); 101 } 102 103 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, 104 const ShapeIndex& index) { 105 if (ContainsBufferAtIndex(buffer, index)) { 106 return; 107 } 108 mutable_element(index)->push_back(&buffer); 109 } 110 111 const PointsToSet::SourceSet& PointsToSet::tuple_sources( 112 const ShapeIndex& index) const { 113 return tree_.element(index).tuple_sources; 114 } 115 116 void PointsToSet::add_tuple_source(const ShapeIndex& index, 117 HloInstruction* tuple) { 118 tree_.mutable_element(index)->tuple_sources.insert(tuple); 119 } 120 121 namespace { 122 // Gather fusion instructions from 'instruction' into 'fusion_instructions'. 123 void GatherFusionInstructions( 124 HloInstruction* instruction, 125 std::vector<HloInstruction*>* fusion_instructions) { 126 CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); 127 for (auto* fused : instruction->fused_instructions()) { 128 if (fused->opcode() == HloOpcode::kFusion) { 129 GatherFusionInstructions(fused, fusion_instructions); 130 } 131 } 132 fusion_instructions->push_back(instruction); 133 } 134 135 } // namespace 136 137 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>> 138 TuplePointsToAnalysis::Run(const HloModule* module) { 139 auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); 140 std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis( 141 module, logical_buffer_analysis.ConsumeValueOrDie())); 142 TF_RETURN_IF_ERROR(analysis->Analyze()); 143 return std::move(analysis); 144 } 145 146 Status TuplePointsToAnalysis::Analyze() { 147 per_instruction_.clear(); 148 per_instruction_.reserve(module_->instruction_count()); 149 150 logical_buffer_aliases_.clear(); 151 logical_buffer_aliases_.resize( 152 logical_buffer_analysis_->num_logical_buffers()); 153 154 std::vector<HloInstruction*> fusion_instructions; 155 for (auto* computation : module_->MakeNonfusionComputations()) { 156 TF_RETURN_IF_ERROR(computation->Accept(this)); 157 TF_RETURN_IF_ERROR( 158 PopulateDefinedBuffersAndAliases(computation->instructions())); 159 for (auto* instruction : computation->instructions()) { 160 if (instruction->opcode() == HloOpcode::kFusion) { 161 GatherFusionInstructions(instruction, &fusion_instructions); 162 } 163 } 164 } 165 // Run points-to analysis on fusion instructions in 'computation'. 166 for (auto* instruction : fusion_instructions) { 167 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); 168 TF_RETURN_IF_ERROR( 169 PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); 170 } 171 172 XLA_VLOG_LINES(3, ToString()); 173 174 return Status::OK(); 175 } 176 177 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype( 178 std::declval<HloComputation>().instructions())& instructions) { 179 for (auto* instruction : instructions) { 180 PerInstruction* pi = PerInst(instruction); 181 TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( 182 instruction, &pi->instruction_defined_buffers)); 183 184 const PointsToSet& points_to_set = GetPointsToSet(instruction); 185 points_to_set.ForEachElement( 186 [this, &instruction]( 187 const ShapeIndex& index, 188 const PointsToSet::BufferList& pointed_to_buffers) { 189 for (const LogicalBuffer* buffer : pointed_to_buffers) { 190 logical_buffer_aliases_[buffer->id()].emplace_back(instruction, 191 index); 192 } 193 }); 194 } 195 return Status::OK(); 196 } 197 198 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { 199 // Create trivial points-to set for instruction. Each points-to set at index i 200 // contains a single element LogicalBuffer(hlo_instruction, i). This indicates 201 // that this instruction is the source of all buffers in its own output. 202 PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction); 203 points_to_set.ForEachMutableElement( 204 [this, hlo_instruction](const ShapeIndex& index, 205 PointsToSet::BufferList* buffers) { 206 buffers->push_back( 207 &logical_buffer_analysis_->GetBuffer(hlo_instruction, index)); 208 }); 209 210 if (hlo_instruction->shape().IsTuple()) { 211 // If the hlo instruction is a tuple-shaped, then trivially the instruction 212 // itself is the source of the tuple. 213 points_to_set.add_tuple_source({}, hlo_instruction); 214 } 215 216 return Status::OK(); 217 } 218 219 Status TuplePointsToAnalysis::HandleGetTupleElement( 220 HloInstruction* get_tuple_element) { 221 // GetTupleElement forwards a pointer to a particular element of the tuple 222 // operand. 223 int64 element_index = get_tuple_element->tuple_index(); 224 225 PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element); 226 const PointsToSet& operand_points_to_set = 227 *PerInst(get_tuple_element->operand(0))->points_to_set; 228 229 // Copy the points-to set (and tuple sources) at index {element_index} of the 230 // operand to the points-to set for this GetTupleElement instruction. 231 points_to_set.ForEachMutableElement( 232 [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) { 233 // Construct an index into the operand by prepending element_index to 234 // the index for the GetTupleElement instruction's points-to set. 235 ShapeIndex src_index; 236 src_index.push_back(element_index); 237 for (auto element : target_index) { 238 src_index.push_back(element); 239 } 240 241 *points_to = operand_points_to_set.element(src_index); 242 for (HloInstruction* tuple : 243 operand_points_to_set.tuple_sources(src_index)) { 244 points_to_set.add_tuple_source(target_index, tuple); 245 } 246 }); 247 248 return Status::OK(); 249 } 250 251 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) { 252 // A kCopy instruction performs a shallow copy of the operand. The top-level 253 // buffer (index={}) is newly created, but all other buffers (in the case of a 254 // tuple shape) come from the operand 255 PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0)); 256 points_to_set.mutable_element(/*index=*/{})->clear(); 257 points_to_set.AddPointedToBuffer( 258 logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}), 259 /*index=*/{}); 260 261 return Status::OK(); 262 } 263 264 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { 265 // A kBitcast instruction aliases its operand. That is, the buffer of its 266 // result *is* the buffer of its operand, so just copy the operands points-to 267 // set. 268 CreateCopiedPointsToSet(bitcast, bitcast->operand(0)); 269 return Status::OK(); 270 } 271 272 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { 273 // A kDomain instruction aliases its operand. That is, the buffer of its 274 // result *is* the buffer of its operand, so just copy the operands points-to 275 // set. 276 CreateCopiedPointsToSet(domain, domain->operand(0)); 277 return Status::OK(); 278 } 279 280 Status TuplePointsToAnalysis::HandleAddDependency( 281 HloInstruction* add_dependency) { 282 // AddDependency just forwards the value of its zero-th operand. 283 CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0)); 284 return Status::OK(); 285 } 286 287 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { 288 // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its 289 // output. The other indices ({} and {1}) define their own buffers. 290 PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); 291 points_to_set.AddPointedToBuffer( 292 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}), 293 /*index=*/{}); 294 points_to_set.AddPointedToBuffer( 295 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}), 296 /*index=*/{1}); 297 298 const PointsToSet& operand_points_to_set = 299 GetPointsToSet(recv_done->operand(0)); 300 301 // Recursively copy the points to set of the operand tuple {0} to the output 302 // element {0}. 303 points_to_set.ForEachMutableElement( 304 [&points_to_set, &operand_points_to_set]( 305 const ShapeIndex& index, PointsToSet::BufferList* buffers) { 306 if (index.empty() || index[0] != 0) { 307 return; 308 } 309 *buffers = operand_points_to_set.element(index); 310 for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) { 311 points_to_set.add_tuple_source(index, tuple_source); 312 } 313 }); 314 return Status::OK(); 315 } 316 317 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { 318 // Send creates a tuple of {aliased operand, U32 context, token}. 319 PointsToSet& points_to_set = CreateEmptyPointsToSet(send); 320 321 // Creates the points to set for the tuple and its element at {1}. 322 auto top_buffer = points_to_set.mutable_element(ShapeIndex({})); 323 top_buffer->push_back( 324 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({}))); 325 points_to_set.add_tuple_source({}, send); 326 327 auto context_buffer = points_to_set.mutable_element(ShapeIndex({1})); 328 context_buffer->push_back( 329 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); 330 331 auto token_buffer = points_to_set.mutable_element(ShapeIndex({2})); 332 token_buffer->push_back( 333 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2}))); 334 335 // Recursively copy the points to set of the operand to output tuple {0}. 336 const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); 337 operand_points_to_set.ForEachElement( 338 [&points_to_set, &operand_points_to_set]( 339 const ShapeIndex& src_index, 340 const PointsToSet::BufferList& points_to) { 341 ShapeIndex target_index({0}); 342 for (auto element : src_index) { 343 target_index.push_back(element); 344 } 345 *points_to_set.mutable_element(target_index) = points_to; 346 347 for (HloInstruction* tuple : 348 operand_points_to_set.tuple_sources(src_index)) { 349 points_to_set.add_tuple_source(target_index, tuple); 350 } 351 }); 352 353 return Status::OK(); 354 } 355 356 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { 357 absl::Span<HloInstruction* const> operands(tuple->operands()); 358 PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); 359 points_to_set.AddPointedToBuffer( 360 logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), 361 /*index=*/{}); 362 363 // A tuple contains references to all input operands and transitively any 364 // references in those operands. 365 for (int64 i = 0; i < operands.size(); ++i) { 366 const PointsToSet& operand_points_to_set = 367 *PerInst(operands[i])->points_to_set; 368 369 // Copy the points-to set (and tuple sources) of the operand into the 370 // respective subtree of the tuple instructions points-to set. 371 operand_points_to_set.ForEachElement( 372 [&points_to_set, &operand_points_to_set, i]( 373 const ShapeIndex& src_index, 374 const PointsToSet::BufferList& points_to) { 375 ShapeIndex target_index; 376 target_index.push_back(i); 377 for (auto element : src_index) { 378 target_index.push_back(element); 379 } 380 381 *points_to_set.mutable_element(target_index) = points_to; 382 383 for (HloInstruction* tuple : 384 operand_points_to_set.tuple_sources(src_index)) { 385 points_to_set.add_tuple_source(target_index, tuple); 386 } 387 }); 388 } 389 390 points_to_set.add_tuple_source({}, tuple); 391 392 return Status::OK(); 393 } 394 395 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { 396 // Select allocates a new buffer and then shallow copies the on_true or 397 // on_false buffer into this new buffer. Which side is chosen cannot be 398 // determined statically so conservatively set the points-to set to the union 399 // of these on_true and on_false operands. 400 // 401 // First create a copy of the on_true points-to set (and tuple sources), then 402 // add in elements of the on_false points-to set (tuple sources). 403 auto on_true = tuple_select->operand(1); 404 auto on_false = tuple_select->operand(2); 405 PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true); 406 const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set; 407 points_to_set.ForEachMutableElement( 408 [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) { 409 for (const LogicalBuffer* false_buffer : 410 false_points_to_set.element(index)) { 411 points_to_set.AddPointedToBuffer(*false_buffer, index); 412 } 413 414 for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) { 415 points_to_set.add_tuple_source(index, tuple); 416 } 417 }); 418 419 // Select creates a new (top-level) buffer to store its result, so its 420 // respective element in the points-to set should contain only itself. 421 points_to_set.mutable_element({})->clear(); 422 points_to_set.AddPointedToBuffer( 423 logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}), 424 /*index=*/{}); 425 return Status::OK(); 426 } 427 428 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet( 429 const HloInstruction* hlo_instruction) const { 430 return *PerInst(hlo_instruction)->points_to_set; 431 } 432 433 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( 434 const HloInstruction* instruction) { 435 PerInstruction* pi = PerInst(instruction); 436 CHECK(pi->points_to_set == nullptr) 437 << "instruction should not have been present in the map."; 438 auto set = absl::make_unique<PointsToSet>(&instruction->shape()); 439 pi->points_to_set = std::move(set); 440 // Return *set using the iterator returned by emplace. 441 return *pi->points_to_set; 442 } 443 444 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( 445 const HloInstruction* instruction, const ShapeIndex& index) const { 446 const auto& buffers = GetPointsToSet(instruction).element(index); 447 return (buffers.size() == 1 && buffers[0]->instruction() == instruction); 448 } 449 450 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { 451 if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { 452 return FailedPrecondition( 453 "LogicalBuffer %s is ill-defined: instruction %s does not define a " 454 "buffer at that index", 455 buffer.ToString(), buffer.instruction()->name()); 456 } 457 458 if (buffer.id() < 0 || 459 buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { 460 return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", 461 buffer.ToString(), buffer.id()); 462 } 463 if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || 464 GetBuffer(buffer.id()).index() != buffer.index()) { 465 return FailedPrecondition( 466 "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", 467 buffer.ToString(), GetBuffer(buffer.id()).ToString()); 468 } 469 470 return Status::OK(); 471 } 472 473 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer( 474 LogicalBuffer::Id id) const { 475 CHECK_GE(id, 0); 476 CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers()); 477 return logical_buffer_analysis_->GetBuffer(id); 478 } 479 480 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt( 481 const HloInstruction* instruction, const ShapeIndex& index) const { 482 const auto& buffers = GetPointsToSet(instruction).element(index); 483 if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { 484 return FailedPrecondition( 485 "instruction %s does not define buffer at index {%s}", 486 instruction->name(), absl::StrJoin(index, ",")); 487 } 488 return buffers[0]; 489 } 490 491 const TuplePointsToAnalysis::BufferAliasVector& 492 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const { 493 return logical_buffer_aliases_.at(buffer.id()); 494 } 495 496 const TuplePointsToAnalysis::BufferDefinitionVector& 497 TuplePointsToAnalysis::GetBuffersDefinedByInstruction( 498 const HloInstruction* instruction) const { 499 return PerInst(instruction)->instruction_defined_buffers; 500 } 501 502 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( 503 const HloInstruction* instruction, 504 TuplePointsToAnalysis::BufferDefinitionVector* buffers) { 505 GetPointsToSet(instruction) 506 .ForEachElement([buffers, instruction]( 507 const ShapeIndex& index, 508 const PointsToSet::BufferList& source_buffers) { 509 // Add buffers which 'instruction' is the source of. 510 CHECK(!source_buffers.empty()); 511 if (source_buffers.size() == 1 && 512 source_buffers[0]->instruction() == instruction) { 513 // If this instruction is the source of this buffer the 514 // indices must match. 515 DCHECK(source_buffers[0]->index() == index); 516 buffers->push_back(source_buffers[0]); 517 } else { 518 // If the points-to set includes more than one buffer then 519 // necessarily this instruction did not produce the 520 // buffer. 521 for (const LogicalBuffer* source_buffer : source_buffers) { 522 DCHECK(source_buffer->instruction() != instruction); 523 } 524 } 525 }); 526 return Status::OK(); 527 } 528 529 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( 530 const HloInstruction* instruction, const HloInstruction* src) { 531 // PointsToSet doesn't have a copy constructor so copy over element-by-element 532 // from src PointsToSet. 533 PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); 534 const PointsToSet& src_points_to_set = GetPointsToSet(src); 535 dst_points_to_set.ForEachMutableElement( 536 [&dst_points_to_set, &src_points_to_set]( 537 const ShapeIndex& index, PointsToSet::BufferList* buffers) { 538 *buffers = src_points_to_set.element(index); 539 for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { 540 dst_points_to_set.add_tuple_source(index, tuple_source); 541 } 542 }); 543 return *PerInst(instruction)->points_to_set; 544 } 545 546 string TuplePointsToAnalysis::ToString() const { 547 string output = 548 absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); 549 for (const auto* computation : module_->MakeNonfusionComputations()) { 550 const char* entry = 551 computation == module_->entry_computation() ? "entry " : ""; 552 absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); 553 for (const HloInstruction* instruction : 554 computation->MakeInstructionPostOrder()) { 555 InstructionToString(instruction, &output); 556 if (instruction->opcode() == HloOpcode::kFusion) { 557 for (auto* fused : instruction->fused_instructions()) { 558 InstructionToString(fused, &output); 559 } 560 } 561 } 562 } 563 564 absl::StrAppend(&output, "LogicalBuffers:\n"); 565 for (const auto& b : logical_buffer_analysis_->logical_buffers()) { 566 absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); 567 for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { 568 absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); 569 } 570 } 571 return output; 572 } 573 574 void TuplePointsToAnalysis::InstructionToString( 575 const HloInstruction* instruction, string* output) const { 576 const string prefix = instruction->IsFused() ? " " : ""; 577 absl::StrAppend(output, prefix, " instruction ", 578 instruction->ToShortString(), ":\n"); 579 const PointsToSet& points_to_set = GetPointsToSet(instruction); 580 points_to_set.ForEachElement([&prefix, &output]( 581 const ShapeIndex& index, 582 const PointsToSet::BufferList& points_to) { 583 absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", 584 absl::StrJoin(points_to, ", ", 585 [](string* out, const LogicalBuffer* source) { 586 out->append(source->ToString()); 587 }), 588 "\n"); 589 }); 590 } 591 592 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( 593 const HloInstruction* operand, const ShapeIndex& index, 594 const HloInstruction* user) const { 595 CHECK(user->IsUserOf(operand)) 596 << "user: " << user->ToString() << " operand: " << operand->ToString(); 597 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { 598 // GetTupleElement instructions only access the top-level buffer of their 599 // operand. 600 return true; 601 } else if (user->opcode() == HloOpcode::kFusion && 602 user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 603 // Find fusion parameter associated with 'operand'. 604 auto it = absl::c_find_if( 605 user->fused_parameters(), [&](HloInstruction* fused_param) { 606 return user->operand(fused_param->parameter_number()) == operand; 607 }); 608 CHECK(it != user->fused_parameters().end()); 609 // Iterate through all users of all buffer aliases of the buffer in the 610 // points-to set of fusion parameter at 'index'. 611 // Return false if any uses are detected at 'index', returns true otherwise. 612 const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); 613 for (const BufferAlias& alias : GetBufferAliases(*buffer)) { 614 for (HloInstruction* alias_user : alias.instruction()->users()) { 615 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 616 alias_user)) { 617 continue; 618 } 619 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. 620 return false; 621 } 622 } 623 // Return true: found no uses of 'operand' at 'index' in 'user'. 624 return true; 625 } 626 return false; 627 } 628 629 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 630 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) 631 // where 'user' is a user of an alias of 'instruction' at 'index', and 632 // 'operand_index' is the operand index at which the alias appears in the 633 // operand list of 'user'. 634 std::vector<std::pair<HloInstruction*, int64>> 635 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex( 636 HloInstruction* instruction, const ShapeIndex& index) const { 637 std::vector<std::pair<HloInstruction*, int64>> uses; 638 const PointsToSet::BufferList& points_to = 639 GetPointsToSet(instruction).element(index); 640 for (const LogicalBuffer* buffer : points_to) { 641 for (const BufferAlias& alias : GetBufferAliases(*buffer)) { 642 for (HloInstruction* alias_user : alias.instruction()->users()) { 643 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 644 alias_user)) { 645 continue; 646 } 647 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { 648 uses.emplace_back(alias_user, op_idx); 649 } 650 } 651 } 652 } 653 return uses; 654 } 655 656 // Returns true if there is exactly one use of 'operand' at 'operand_index' 657 // in 'fusion.fused_instructions', where the singleton use is the fused 658 // root at operand index 'use_operand_index'. Returns false otherwise. 659 // 660 // REQUIRES: 'fusion' opcode is a kFusion instruction. 661 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( 662 HloInstruction* operand, const ShapeIndex& operand_index, 663 HloInstruction* fusion, const int64 use_operand_index) const { 664 CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); 665 // Check that 'operand' is unique in the operand list of 'fusion'. 666 if (fusion->OperandIndices(operand).size() > 1) { 667 return false; 668 } 669 // Find fusion parameter associated with 'operand'. 670 const auto& fused_params = fusion->fused_parameters(); 671 auto fused_param_it = 672 absl::c_find_if(fused_params, [&](HloInstruction* fused_param) { 673 return fusion->operand(fused_param->parameter_number()) == operand; 674 }); 675 if (fused_param_it == fused_params.end()) { 676 return false; 677 } 678 auto* fused_param = *fused_param_it; 679 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. 680 auto fused_param_uses = 681 GetAllUsesOfInstructionAtIndex(fused_param, operand_index); 682 // Return true iff there is exactly one use of 'operand' at 'index', and 683 // this singleton use is the fused root (at index in 'use_operand_indices'). 684 return fused_param_uses.size() == 1 && 685 fused_param_uses[0].first == fusion->fused_expression_root() && 686 fused_param_uses[0].second == use_operand_index; 687 } 688 689 // User and operand can share buffers iff both instructions emit the same shape 690 // and layout, and 'user' meets one of the following qualifications: 691 // 692 // (1) Is element-wise. Or... 693 // (2) Is a loop fusion instruction where the only use of 'operand' at 'index' 694 // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root 695 // at operand 0. Or... 696 // (3) Is a kDot -> kAdd output fusion instruction where the only use of 697 // 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused 698 // root at operand 0 or 1. Or... 699 // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 700 // 0. 701 // (5) The 'user' of 'operand' is Sort, and it is the only user. 702 // (6) The 'user' of 'operand' is TriangularSolve, it is the second operand, 703 // and it is the only user. 704 // 705 // (2) and (3) can only be determined if points-to analysis is available. 706 bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( 707 HloInstruction* operand, const ShapeIndex& operand_index, 708 HloInstruction* user, const ShapeIndex& user_index) const { 709 CHECK(user->IsUserOf(operand)) 710 << "user: " << user->ToString() << " operand: " << operand->ToString(); 711 const Shape& operand_subshape = 712 ShapeUtil::GetSubshape(operand->shape(), operand_index); 713 const Shape& user_subshape = 714 ShapeUtil::GetSubshape(user->shape(), user_index); 715 // Check that operand and user emit the same shape and layout. 716 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 717 return false; 718 } 719 if (user->opcode() == HloOpcode::kFusion) { 720 if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || 721 user->fusion_kind() == HloInstruction::FusionKind::kInput) { 722 if (user->fused_expression_root()->opcode() == 723 HloOpcode::kDynamicUpdateSlice) { 724 // Loop fusion with kDynamicUpdateSlice fused root. 725 // 726 // Returns true iff there is exactly one use of 'operand' at shape index 727 // 'operand_index', and this singleton use is the fused root at operand 728 // index 0. 729 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); 730 } else { 731 HloInstruction* fusion_param = 732 user->fused_parameter(user->operand_index(operand)); 733 return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( 734 fusion_param); 735 } 736 } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && 737 user->fused_expression_root()->opcode() == HloOpcode::kAdd) { 738 // Output fusion with kAdd fused root. 739 740 // Check if one operand of kAdd fused root is kDot or kConvolution. 741 auto* add = user->fused_expression_root(); 742 auto add_operand_it = 743 absl::c_find_if(add->operands(), [&](HloInstruction* operand) { 744 return operand->opcode() == HloOpcode::kConvolution || 745 operand->opcode() == HloOpcode::kDot; 746 }); 747 if (add_operand_it == add->operands().end()) { 748 return false; 749 } 750 auto* matched_add_operand = *add_operand_it; 751 // Calculate operand index of 'add' operand which was not matched above. 752 const int64 other_add_operand_index = 753 matched_add_operand == add->operand(0) ? 1 : 0; 754 // Returns true iff there is exactly one use of 'operand' at shape index 755 // 'operand_index', and this singleton use is the fused root (at operand 756 // index 'other_add_operand_index'). 757 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 758 other_add_operand_index); 759 } 760 } 761 if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 762 user->opcode() == HloOpcode::kScatter || 763 user->opcode() == HloOpcode::kWhile) { 764 // We eliminated other users in BufferLiveness::live_range_strictly_before, 765 // so here we just need to check that the use is at operand index 0. 766 std::vector<int64> operand_indices = user->OperandIndices(operand); 767 return operand_indices.size() == 1 && operand_indices[0] == 0; 768 } 769 if (user->opcode() == HloOpcode::kSort) { 770 // Only valid if there are no other users. 771 if (operand->users().size() != 1) { 772 return false; 773 } 774 // If we only sort keys, the output of sort is not a tuple, so we can always 775 // share the buffer. 776 if (user->operand_count() == 1) { 777 return true; 778 } 779 CHECK(!user_index.empty()); 780 // Only share with the right tuple element buffer. 781 std::vector<int64> operand_indices = user->OperandIndices(operand); 782 return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; 783 } 784 if (user->opcode() == HloOpcode::kTriangularSolve) { 785 // Only valid if there are no other users. 786 if (operand->users().size() != 1) { 787 return false; 788 } 789 std::vector<int64> operand_indices = user->OperandIndices(operand); 790 return operand_indices.size() == 1 && operand_indices[0] == 1; 791 } 792 if (user->opcode() == HloOpcode::kCall) { 793 // TODO(b/62548313): Remove when buffer assignment is module scoped and 794 // does not assign buffers to calls. 795 // Find called computation parameter associated with 'operand'. 796 const std::vector<int64> operand_indices = user->OperandIndices(operand); 797 if (operand_indices.size() > 1) { 798 return false; 799 } 800 CHECK_EQ(1, operand_indices.size()); 801 auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); 802 // Get all uses of 'operand' at 'index' in called computation. 803 auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index); 804 805 // Return true iff: 806 // *) There exists exactly one use of 'operand' in called computation. 807 // *) The unique use is by the root instruction of called computation. 808 // (Note: we check the root of the called computation, because the 809 // root result buffer is required to alias with the Call result buffer). 810 // *) The root instruction of the called computation is element-wise on 811 // 'operand'. 812 auto* callee_root = user->to_apply()->root_instruction(); 813 return param_uses.size() == 1 && param_uses[0].first == callee_root && 814 callee_root->IsElementwiseOnOperand(param_uses[0].second); 815 } 816 // Loop fusions that contain transposing copies won't reach here as they have 817 // different layouts, which fails the check in the beginning of this function. 818 // 819 // Multi-output fusion will fail the check here as tuples are not considered 820 // an elementwise operation. 821 return user->IsElementwiseOnOperand(user->operand_index(operand)); 822 } 823 824 } // namespace xla 825