/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <algorithm>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <ostream>
#include <set>
#include <string>
#include <tuple>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

// For now only one API has moved here, but eventually this file should have a
// single top-level anonymous namespace instead of three or four spread all
// over it.
namespace {

// Creates and returns a copy of the given instruction with a different
// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
// instruction producing the copy is returned.
StatusOr<HloInstruction*> CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (ShapeUtil::IsTuple(instruction->shape())) {
    // Deep-copy tuples.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
              i));

      // Recurse to copy each element.
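      // E.g., for a nested tuple (f32[2,3], (s32[4], f32[])) this extracts
      // each element with a kGetTupleElement; array leaves are copied with a
      // kCopy while tuple elements recurse, and the pieces are reassembled
      // with kTuple instructions carrying the requested layouts.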
      TF_ASSIGN_OR_RETURN(
          HloInstruction * element_copy,
          CreateCopyWithNewLayout(
              ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
      element_copies.push_back(element_copy);
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (ShapeUtil::IsArray(instruction->shape())) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}

// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
                                  HloInstruction* instruction,
                                  int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}

}  // namespace

std::ostream& operator<<(std::ostream& out,
                         const LayoutConstraint& constraint) {
  out << constraint.ToString();
  return out;
}

BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}

string BufferLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s",
                                     buffer_->ToString().c_str(),
                                     LayoutUtil::HumanString(layout_).c_str());
}

OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}

string OperandLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf(
      "OperandLayoutConstraint %s, operand %lld: %s",
      instruction_->name().c_str(), operand_no_,
      shape_layout_.ToString().c_str());
}

string ResultLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf("ResultLayoutConstraint: %s",
                                     shape_layout_.ToString().c_str());
}

LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers();
       id++) {
    auto& buffer = points_to_analysis_.logical_buffer(id);
    // The points-to analysis is computed per module; restrict constraints to
    // array buffers in this computation.
    if (buffer.IsArray() && buffer.instruction()->parent() == computation) {
      unconstrained_buffer_ids_.insert(buffer.id());
    }
  }
}

bool LayoutConstraints::OperandBufferForwarded(
    const HloInstruction* instruction, int64 operand_no) const {
  // The operand is potentially forwarded if the intersection of points-to sets
  // of the operand and the instruction is non-empty.
  auto output_buffers =
      points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet();
  auto operand_buffers =
      points_to_analysis_.GetPointsToSet(instruction->operand(operand_no))
          .CreateFlattenedSet();
  for (const LogicalBuffer* output_buffer : output_buffers) {
    if (operand_buffers.count(output_buffer) > 0) {
      return true;
    }
  }
  return false;
}

Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                          const LogicalBuffer& buffer,
                                          bool mandatory, bool dfs) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout);

  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString().c_str(),
        ShapeUtil::HumanString(buffer.shape()).c_str());
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  const BufferLayoutConstraint* curr_constraint =
      GetBufferLayoutConstraint(buffer);
  if (curr_constraint != nullptr) {
    if (LayoutUtil::Equal(curr_constraint->layout(), layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_constraint->mandatory()) {
      return FailedPrecondition(
          "Buffer %s already has the layout constraint %s, cannot add "
          "incompatible constraint %s",
          buffer.ToString().c_str(),
          LayoutUtil::HumanString(curr_constraint->layout()).c_str(),
          LayoutUtil::HumanString(layout).c_str());
    }
  }

  auto iter = buffer_constraints_.find(&buffer);
  bool overwrite = iter != buffer_constraints_.end();
  if (!overwrite) {
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer,
                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
               .first;
  } else {
    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  // Remove buffer from the set of unconstrained buffers.
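  // A freshly constrained buffer must still be present in
  // unconstrained_buffer_ids_, while a buffer whose existing (non-mandatory)
  // constraint was overwritten above must already have been removed; the
  // check below asserts exactly that invariant.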
  TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) ==
               static_cast<int>(!overwrite));
  unconstrained_buffer_ids_.erase(buffer.id());

  return Status::OK();
}

Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64 operand_no, bool mandatory,
                                           bool dfs) {
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %lld of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name().c_str(),
          curr_shape_layout->shape_layout().ToString().c_str(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint
  // changes layouts beyond this immediate use and is complicated to handle.
  if (OperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constrain layout of operand %lld of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name().c_str());
  }

  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
                                     instruction, operand_no, mandatory, dfs));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  return Status::OK();
}

Status LayoutConstraints::SetArrayOperandLayout(
    const Layout& layout, const HloInstruction* instruction, int64 operand_no,
    bool mandatory, bool dfs) {
  const HloInstruction* operand = instruction->operand(operand_no);
  TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
  Shape shape(operand->shape());
  *shape.mutable_layout() = layout;
  TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
  return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
}

Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
                                          bool dfs) {
  VLOG(3) << "SetResultLayout : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const ShapeLayout* curr_shape_layout = ResultLayout();
  if (curr_shape_layout != nullptr) {
    if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
      return FailedPrecondition(
          "Result of computation %s already has the layout constraint %s, "
          "cannot add incompatible constraint %s",
          computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
    }
    // New constraint matches existing constraint. Nothing to do.
    return Status::OK();
  }

  result_constraint_.reset(
      new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
  added_constraints_.push_back(result_constraint_.get());

  return Status::OK();
}

Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name().c_str(),
        ShapeUtil::HumanString(instruction->shape()).c_str(),
        ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        if (ShapeUtil::IsArray(subshape)) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}

const Layout* LayoutConstraints::BufferLayout(
    const LogicalBuffer& buffer) const {
  if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
    return &constraint->layout();
  }
  return nullptr;
}

const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
    const LogicalBuffer& buffer) const {
  auto it = buffer_constraints_.find(&buffer);
  return it == buffer_constraints_.end() ? nullptr : &it->second;
}

const ShapeLayout* LayoutConstraints::OperandLayout(
    const HloInstruction* instruction, int64 operand_no) const {
  if (const auto* constraint =
          GetOperandLayoutConstraint(instruction, operand_no)) {
    return &constraint->shape_layout();
  }
  return nullptr;
}

const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
    const HloInstruction* instruction, int64 operand_no) const {
  auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
  return it == operand_constraints_.end() ? nullptr : &it->second;
}

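// Returns the layout constrained for the computation result, or nullptr if
// the result layout has not been constrained yet.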
const ShapeLayout* LayoutConstraints::ResultLayout() const {
  return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
}

string LayoutConstraints::ToString() const {
  string output;
  tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
                                 computation_->name(), ":\n");
  for (auto* instruction : computation_->MakeInstructionPostOrder()) {
    tensorflow::strings::StrAppend(&output, "  ", instruction->ToShortString(),
                                   "\n");
    for (int64 i = 0; i < instruction->operand_count(); ++i) {
      if (OperandLayout(instruction, i) != nullptr) {
        tensorflow::strings::StrAppend(
            &output, "    operand (", i,
            "): ", OperandLayout(instruction, i)->ToString(), "\n");
      }
    }
    for (const LogicalBuffer* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (BufferLayout(*buffer) != nullptr) {
        tensorflow::strings::StrAppend(
            &output, "    ", buffer->ToString(), " : ",
            LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
      }
    }
  }

  if (ResultLayout() != nullptr) {
    tensorflow::strings::StrAppend(&output, "  => ", ResultLayout()->ToString(),
                                   "\n");
  }
  return output;
}

Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout& computation_layout,
    const ChannelLayoutConstraints* channel_constraints,
    HloComputation* computation, LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    Shape const* shape_with_layout = nullptr;
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      DCHECK(!LayoutUtil::IsPadded(instruction->shape()));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      // Parameter layouts must match the respective layout in
      // ComputationLayout.
      shape_with_layout =
          &computation_layout.parameter_layout(instruction->parameter_number())
               .shape();
    }
    if (shape_with_layout != nullptr) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(*shape_with_layout, instruction));
    }

    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      CHECK(channel_constraints)
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id();
      if (!channel_constraints->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
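        // Both ends of a cross-module channel must agree on the layout, so
        // impose the layout already registered for this channel. Because
        // SetOperandLayout() cannot be used yet (see the TODO above), the
        // constraint is placed directly on the operand instruction.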
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape));
        Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel(
            send_buffer_shape, instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape));
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape = channel_constraints->LayoutShapeForChannel(
            recv_buffer_shape, instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // The layouts of the input and output of a kWhile instruction must be
      // equal and must match both the input and the output of the body
      // computation. The input of the condition computation must also match
      // the kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      const ComputationLayout& body_layout =
          FindOrDie(computation_layouts_, body);
      const ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // Return an error if earlier layout assignment of the embedded
      // computations has produced conflicting layouts.
      if (!ShapeUtil::Equal(body_layout.result_shape(),
                            body_layout.parameter_shape(0))) {
        return InternalError(
            "Parameter and result of body computation %s of while instruction "
            "%s have different layouts: %s vs %s",
            body->name().c_str(), instruction->name().c_str(),
            ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
            ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
      }
      if (!ShapeUtil::Equal(body->root_instruction()->shape(),
                            condition->parameter_instruction(0)->shape())) {
        return InternalError(
            "Parameter of condition computation %s of while instruction "
            "%s does not match body computation %s result: %s vs %s",
            condition->name().c_str(), instruction->name().c_str(),
            body->name().c_str(),
            ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
            ShapeUtil::HumanString(body_layout.result_shape()).c_str());
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // The layout of the true and false computations must match, and must
      // be the layout of the kConditional instruction.
      TF_RET_CHECK(instruction->operand_count() == 3);

      HloComputation* true_computation = instruction->true_computation();
      HloComputation* false_computation = instruction->false_computation();
      const HloInstruction* true_operand = instruction->operand(1);
      const HloInstruction* false_operand = instruction->operand(2);

      TF_RET_CHECK(true_computation->num_parameters() == 1);
      TF_RET_CHECK(false_computation->num_parameters() == 1);
      ComputationLayout& true_computation_layout =
          FindOrDie(computation_layouts_, true_computation);
      ComputationLayout& false_computation_layout =
          FindOrDie(computation_layouts_, false_computation);

      DCHECK(ShapeUtil::Compatible(true_operand->shape(),
                                   true_computation_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(
          false_operand->shape(), false_computation_layout.parameter_shape(0)));

      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          true_computation_layout.result_shape(), instruction));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          true_computation_layout.parameter_shape(0), instruction, 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          false_computation_layout.parameter_shape(0), instruction, 2,
          /*mandatory=*/true));
    } else if (instruction->opcode() == HloOpcode::kCustomCall) {
      if (!CustomCallRequiresMajorFirstLayout(instruction)) {
        continue;
      }
      // Add constraints for kCustomCall instruction operands and the
      // instruction itself. For now we only support major-first layouts for
      // all inputs and outputs.
      Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
          instruction->shape().element_type(),
          AsInt64Slice(instruction->shape().dimensions()));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(result_shape, instruction));
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        const Shape& operand_shape = instruction->operand(i)->shape();
        // Opaque operands don't get a layout constraint.
        if (ShapeUtil::IsOpaque(operand_shape)) {
          continue;
        }

        Shape row_major_operand_shape =
            ShapeUtil::MakeShapeWithDescendingLayout(
                operand_shape.element_type(),
                AsInt64Slice(operand_shape.dimensions()));
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            row_major_operand_shape, instruction, i));
      }
    }
  }

  // Finally set the result layout to match ComputationLayout.
  return constraints->SetResultLayout(
      computation_layout.result_layout().shape());
}

namespace {

// The operands of a call must match the layouts of parameters in the
// ComputationLayout, and the call instruction itself must match the result
// layout in the ComputationLayout.
Status CheckCallLayout(HloInstruction* call,
                       const ComputationLayout& computation_layout) {
  HloComputation* computation = call->to_apply();
  TF_RET_CHECK(computation->num_parameters() == call->operand_count());
  for (int64 i = 0; i < computation->num_parameters(); ++i) {
    TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
        call->operand(i)->shape()));
  }
  TF_RET_CHECK(
      computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
  return Status::OK();
}

// Custom calls have fixed input and output layouts.
Status CheckCustomCallLayout(HloInstruction* custom_call) {
  for (const HloInstruction* operand : custom_call->operands()) {
    TF_RET_CHECK(
        ShapeUtil::IsOpaque(operand->shape()) ||
        LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
  }
  TF_RET_CHECK(
      ShapeUtil::IsOpaque(custom_call->shape()) ||
      LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
  return Status::OK();
}

// For a while instruction, all the following layouts must be the same:
//   (1) init operand
//   (2) condition computation parameter
//   (3) body computation parameter
//   (4) body computation result
//   (5) while instruction result
Status CheckWhileLayout(HloInstruction* while_inst,
                        const ComputationLayout& condition_computation_layout,
                        const ComputationLayout& body_computation_layout) {
  auto init_shape = while_inst->operand(0)->shape();
  TF_RET_CHECK(
      condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
          init_shape));
  TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
      init_shape));
  TF_RET_CHECK(
      body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
  TF_RET_CHECK(
      LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
  return Status::OK();
}

Status CheckConditionalLayout(
    HloInstruction* instruction,
    const ComputationLayout& true_computation_layout,
    const ComputationLayout& false_computation_layout) {
  HloComputation* true_computation = instruction->true_computation();
  HloComputation* false_computation = instruction->false_computation();
  const HloInstruction* true_operand = instruction->operand(1);
  const HloInstruction* false_operand = instruction->operand(2);

  TF_RET_CHECK(true_computation_layout.result_layout() ==
               false_computation_layout.result_layout());
  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
      instruction->shape()));
  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
      true_computation->root_instruction()->shape()));
  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
      instruction->shape()));
  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
      false_computation->root_instruction()->shape()));
  TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
      true_operand->shape()));
  TF_RET_CHECK(
      false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
          false_operand->shape()));
  return Status::OK();
}

// Fusion parameters must match the layouts of the fusion instruction's
// operands, and the root of the fused expression must match the layout of the
// fusion instruction.
Status CheckFusionLayout(HloInstruction* fusion) {
  TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());

  TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
      fusion->shape(), fusion->fused_expression_root()->shape()));
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
        fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
  }
  return Status::OK();
}

// The layout of a parameter must match the respective layout in the
// computation's ComputationLayout.
Status CheckParameterLayout(HloInstruction* parameter,
                            const ComputationLayout& computation_layout) {
  const ShapeLayout& parameter_layout =
      computation_layout.parameter_layout(parameter->parameter_number());
  if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
    return InternalError(
        "parameter instruction %s does not match layout of computation "
        "shape: %s",
        parameter->ToString().c_str(), parameter_layout.ToString().c_str());
  }
  return Status::OK();
}

// The layout of a constant instruction must match the layout of its literal.
Status CheckConstantLayout(HloInstruction* constant) {
  if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
                                        constant->shape())) {
    return InternalError(
        "constant instruction %s does not match the layout of its literal %s",
        constant->ToString().c_str(),
        ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
  }
  return Status::OK();
}

}  // namespace

Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
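      // That is, at every leaf index the shape recorded on the instruction
      // (layout included) must equal the shape of each LogicalBuffer that may
      // supply the value at that position.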
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name().c_str(),
                      tensorflow::str_util::Join(index, ",").c_str(),
                      buffer->ToString().c_str(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape)
                          .c_str(),
                      ShapeUtil::HumanStringWithLayout(buffer->shape())
                          .c_str());
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          if (CustomCallRequiresMajorFirstLayout(instruction)) {
            TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          }
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional:
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->true_computation()),
              FindOrDie(computation_layouts_,
                        instruction->false_computation())));
          break;
        default:
          break;
      }
    }
  }

  // Finally verify the result layout matches the layout of the entry
  // computation root.
  TF_RET_CHECK(ShapeUtil::Equal(
      module->entry_computation()->root_instruction()->shape(),
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout()
          .shape()));

  return Status::OK();
}

LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),
      channel_layout_constraints_(channel_constraints) {
  VLOG(1) << "entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
  // Layouts of all parameter instructions must be set.
  for (const ShapeLayout& parameter_layout :
       entry_computation_layout_->parameter_layouts()) {
    CHECK(parameter_layout.LayoutIsSet());
  }
  // If the result layout is not set, then choose the default.
  // TODO(b/29118294): Choose a better layout in this case.
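  // SetToDefaultLayout() gives every array subshape of the result the default
  // layout, i.e. minor_to_major = {rank-1, ..., 1, 0}, with dimension 0 the
  // most major.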
  if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
    entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
  }
}

std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);

  CHECK(ShapeUtil::IsArray(instruction->shape()));
  CHECK(ShapeUtil::IsArray(operand->shape()));

  if (instruction->IsElementwiseOnOperand(operand_no) &&
      !ShapeUtil::IsScalar(operand->shape()) &&
      ShapeUtil::Rank(operand->shape()) ==
          ShapeUtil::Rank(instruction->shape())) {
    // Assign operands the same layout as the instruction, so that
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    //
    // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
    // from assigning the same layout to input and output.
    return MakeUnique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try whether the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::Rank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
      return MakeUnique<Layout>(operand_shape.layout());
    }
    if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
      *operand_shape.mutable_layout() = output_layout;
      if (ShapeUtil::ReshapeIsBitcast(operand_shape,
                                      output_shape_with_layout)) {
        return MakeUnique<Layout>(output_layout);
      }
    }
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return MakeUnique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
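    // Worked example: for a rank-2 transpose with dimensions() = {1,0} and an
    // output layout with minor_to_major = {0,1}, the loop below computes
    // new_minor_to_major = {1,0}: whichever operand dimension ends up minor
    // in the output is made minor in the operand, so the transpose merely
    // relabels dimensions and can be lowered to a bitcast.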
    int64 rank = ShapeUtil::Rank(instruction->shape());
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return MakeUnique<Layout>(operand_layout);
  }

  return nullptr;
}

std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(ShapeUtil::IsArray(user->shape()) &&
        ShapeUtil::IsArray(operand->shape()));

  if (user->IsElementwiseOnOperand(operand_no) &&
      !ShapeUtil::IsScalar(operand->shape()) &&
      ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
    // Assign users the same layout as the operand.
    return MakeUnique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try whether the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::Rank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
      return MakeUnique<Layout>(output_shape.layout());
    }
    if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
      *output_shape.mutable_layout() = operand_layout;
      if (ShapeUtil::ReshapeIsBitcast(output_shape,
                                      operand_shape_with_layout)) {
        return MakeUnique<Layout>(operand_layout);
      }
    }
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return MakeUnique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
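    // Mirror image of the operand case above: here the operand layout is
    // fixed, so the transpose permutation is inverted to find which user
    // dimension each operand dimension maps to. Continuing the example above,
    // dimensions() = {1,0} and an operand minor_to_major of {1,0} yield a
    // user layout of {0,1}, again making the transpose a bitcast.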
    int64 rank = ShapeUtil::Rank(user->shape());
    std::vector<int64> new_minor_to_major(rank);
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return MakeUnique<Layout>(user_layout);
  }

  return nullptr;
}

Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints, which makes it less likely that conflicting constraints will
  // be propagated to instructions. However, we should experiment with other
  // orders too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors.";
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    add_new_constraints_to_worklist();
  }
  return Status::OK();
}

namespace {

// Returns a vector containing all array-shaped uses (instruction and operand
// number) of the given logical buffer or its aliases.
std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
    const LogicalBuffer& buffer,
    const TuplePointsToAnalysis& points_to_analysis) {
  CHECK(buffer.IsArray());
  std::vector<std::pair<const HloInstruction*, int64>> uses;
  for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
    if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) {
      continue;
    }
    // This alias must be the top-level (index == {}) of the instruction's
    // result because the instruction produces an array.
    CHECK(buffer_alias.index().empty());

    // Add all uses of the instruction's output.
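    // A use is an (instruction, operand_index) pair, so a user consuming the
    // aliased value through several operands is recorded once per operand.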
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      for (int64 operand_no :
           user->OperandIndices(buffer_alias.instruction())) {
        uses.emplace_back(user, operand_no);
      }
    }
  }
  return uses;
}

}  // namespace

Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [this, &shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            if (constraints->BufferLayout(*buffer) == nullptr &&
                ShapeUtil::IsArray(buffer->shape())) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}

Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum-cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout, we want the output of the instruction to
  // have the same layout.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!ShapeUtil::IsArray(operand->shape()) ||
      !ShapeUtil::IsArray(user->shape())) {
    return Status::OK();
  }

  // Only try to choose a low-cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }
  TF_ASSIGN_OR_RETURN(
      const LogicalBuffer* buffer,
      constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{}));

  if (constraints->BufferLayout(*buffer) == nullptr) {
    std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
        operand_constraint.shape_layout().layout(), user,
        operand_constraint.operand_no());
    if (layout != nullptr) {
      TF_RETURN_IF_ERROR(
          constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false));
    }
  }
  return Status::OK();
}

Status LayoutAssignment::PropagateBufferConstraint(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  // Only propagate array layouts.
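  // (A tuple-shaped LogicalBuffer is only the table of pointers to its
  // elements; the element arrays are separate buffers with their own
  // constraints, so there is no layout to propagate here.)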
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  if (!buffer.IsArray()) {
    return Status::OK();
  }

  // If this buffer is the result of an array-shaped op (as opposed to an array
  // element in a tuple) try to propagate the layout to its operands.
  if (buffer.IsTopLevel()) {
    const HloInstruction* instruction = buffer.instruction();
    // Propagate the def-constraint on an instruction to the use-constraints on
    // its operands (use-def propagation).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      if (constraints->OperandLayout(instruction, operand_no) == nullptr &&
          ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) {
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/true));
        }
      }
    }
  }
  return PropagateBufferConstraintToUses(buffer_constraint, constraints);
}

Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  return Status::OK();
}

Status LayoutAssignment::PropagateResultConstraint(
    const ResultLayoutConstraint& result_constraint,
    LayoutConstraints* constraints) {
  // Propagate the use constraint of the root instruction up to the logical
  // buffers which make up the result.
  return PropagateUseConstraintToDefs(
      result_constraint.shape_layout(),
      constraints->computation()->root_instruction(), constraints);
}

namespace {

// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands).
StatusOr<Layout> InferArrayLayout(
    const TuplePointsToAnalysis& points_to_analysis,
    HloInstruction* instruction, const ShapeIndex& index) {
  // This function should only be called for array shapes which don't yet have
  // layouts.
  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
  TF_RET_CHECK(ShapeUtil::IsArray(subshape));
  TF_RET_CHECK(!subshape.has_layout());

  // The instruction should not define the buffer at this index.
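  // E.g., a kTuple result forwards its operands' buffers at the element
  // indices; those forwarded positions are exactly the ones whose layouts are
  // inferred here rather than constrained directly.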
  TF_RET_CHECK(
      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
      << instruction->ToString();

  const auto& source_buffers =
      points_to_analysis.GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    if (!source_buffer->shape().has_layout()) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString().c_str());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer->shape().layout();
    } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
                                  *first_buffer_layout)) {
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          tensorflow::str_util::Join(index, ",").c_str(),
          instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
          source_buffer->ToString().c_str());
    }
  }

  return *first_buffer_layout;
}

// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else {
      // Other instructions don't have layouts inside of fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}

}  // namespace

Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Set the layouts of the array shapes this instruction defines as
    // indicated by the respective BufferLayoutConstraints. Any array shapes
    // in the output of the instruction which are not defined by the
    // instruction (eg, array elements in a Tuple instruction) will be
    // assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!ShapeUtil::IsArray(buffer->shape())) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) {
            return Status::OK();
          }
          // Set the layout of the subshape to match the layout of the
          // LogicalBuffer which produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute an extra verification step once the layout has been finalized.
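    // (Verify() is intended as a hook for backend-specific layout invariants;
    // the backend-independent checks live in CheckLayouts() above.)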
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }

  // Copy the root instruction's result if its layout does not match the
  // result layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }

  return Status::OK();
}

Status LayoutAssignment::RunOnComputation(
    const ComputationLayout& computation_layout,
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  DCHECK(computation_layout.LayoutIsSet());
  InsertOrDie(&computation_layouts_, computation, computation_layout);
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";
  VLOG(2) << "  ComputationLayout = " << computation_layout.ToString();

  // Construct LayoutConstraints with all layout constraints of the
  // computation.
  LayoutConstraints constraints(points_to_analysis, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagate layouts from the mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // While any unconstrained buffers remain, pick an arbitrary buffer, give
  // it a layout, and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the
    // default layout (or the literal's layout, in the case of constants). By
    // construction, unconstrained_buffer_ids() is stably sorted by
    // LogicalBuffer::Id, so this choice is deterministic.
    const LogicalBuffer& buffer = points_to_analysis.GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of
    // unconstrained buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }

  // All logical buffers should have constraints at this point. All that
  // remains is to assign the constraints to the buffers and infer layouts
  // for aliased buffers.
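  // For example, a rank-2 f32[10,20] buffer that was still unconstrained
  // above will have received the default minor_to_major layout {1, 0}, i.e.
  // row-major order.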
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // Record the layouts assigned to any communication ops in
  // channel_constraints so that they are constrained for future modules.
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      channel_constraints->ConstrainChannel(
          instruction->channel_id(),
          instruction->operand(0)->shape().layout());
    } else if (instruction->opcode() == HloOpcode::kRecvDone) {
      channel_constraints->ConstrainChannel(instruction->channel_id(),
                                            instruction->shape().layout());
    }
  }
  return Status::OK();
}

StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  if (VLOG_IS_ON(10)) {
    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                "before layout assignment",
                                module->config().debug_options());
  }

  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));

  // Assign layouts to computations in an order such that a callee
  // computation is handled before its caller computation. This ensures that
  // the layout of all callers of a computation will agree.
  std::list<HloComputation*> computation_post_order =
      module->MakeComputationPostOrder();
  for (auto* computation : computation_post_order) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    // Clear the existing layouts of the instructions. All layouts must be
    // assigned by the LayoutAssignment pass, except for those on infeeds,
    // parameters, and the computation result. The latter two are specified
    // in computation_layout, so we only need to keep the existing layouts
    // for infeeds. Clearing the layouts here avoids hiding potential bugs in
    // the layout assignment pass that may accidentally use the existing
    // layout.
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kInfeed) {
        LayoutUtil::ClearLayout(instruction->mutable_shape());
      }
    }
    if (computation == module->entry_computation()) {
      TF_RETURN_IF_ERROR(RunOnComputation(
          *entry_computation_layout_, *points_to_analysis,
          module->entry_computation(), channel_layout_constraints_));
    } else {
      ComputationLayout computation_layout(computation->ComputeProgramShape());
      // Setting all embedded computations to the default layout is
      // potentially suboptimal.
      computation_layout.SetToDefaultLayout();
      TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
                                          *points_to_analysis, computation,
                                          channel_layout_constraints_));
    }
  }

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  VLOG(3) << "After layout assignment:";
  XLA_VLOG_LINES(3, module->ToString());
  if (VLOG_IS_ON(10)) {
    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                "after layout assignment",
                                module->config().debug_options());
  }

  // All layouts are reset and then reassigned by this pass, so the pass
  // always reports that it changed the module.
  return true;
}

}  // namespace xla