/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <algorithm>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <ostream>
#include <set>
#include <string>
#include <tuple>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
// For now only one API is moved here, but we should have a single top-level
// anonymous namespace instead of three or four spread all over this file.
namespace {

// Creates and returns a copy of the given instruction with a different
// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
// instruction producing the copy is returned.
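//
// Illustrative sketch (hypothetical HLO, not taken from this file): deep-
// copying a tuple-shaped value
//   %t = (f32[2,3], f32[4]) tuple(%a, %b)
// into a shape whose leaves carry layouts {0,1} and {0} yields roughly
//   %gte.0 = get-tuple-element(%t), index=0
//   %copy.0 = copy(%gte.0)             // assigned layout {0,1}
//   %gte.1 = get-tuple-element(%t), index=1
//   %copy.1 = copy(%gte.1)             // assigned layout {0}
//   %tuple = tuple(%copy.0, %copy.1)   // this last Tuple is returned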
StatusOr<HloInstruction*> CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (ShapeUtil::IsTuple(instruction->shape())) {
    // Deep-copy tuples.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
              i));

      // Recurse to copy each element.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * element_copy,
          CreateCopyWithNewLayout(
              ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
      element_copies.push_back(element_copy);
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (ShapeUtil::IsArray(instruction->shape())) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}

// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
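//
// Usage sketch (hypothetical names): if %user reads operand 0 with layout
// {0,1} but the constraint demands {1,0}, this inserts
//   %copy = copy(%operand)   // layout {1,0}
// and rewires %user to read %copy; other users of %operand are unaffected.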
Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
                                  HloInstruction* instruction,
                                  int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}

}  // namespace

std::ostream& operator<<(std::ostream& out,
                         const LayoutConstraint& constraint) {
  out << constraint.ToString();
  return out;
}

BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}

string BufferLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s",
                                     buffer_->ToString().c_str(),
                                     LayoutUtil::HumanString(layout_).c_str());
}

OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}

string OperandLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf(
      "OperandLayoutConstraint %s, operand %lld: %s",
      instruction_->name().c_str(), operand_no_,
      shape_layout_.ToString().c_str());
}

string ResultLayoutConstraint::ToString() const {
  return tensorflow::strings::Printf("ResultLayoutConstraint: %s",
                                     shape_layout_.ToString().c_str());
}

LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers();
       id++) {
    auto& buffer = points_to_analysis_.logical_buffer(id);
    // The points-to analysis is computed per module; restrict constraints to
    // array buffers in this computation.
    if (buffer.IsArray() && buffer.instruction()->parent() == computation) {
      unconstrained_buffer_ids_.insert(buffer.id());
    }
  }
}

bool LayoutConstraints::OperandBufferForwarded(
    const HloInstruction* instruction, int64 operand_no) const {
  // The operand is potentially forwarded if the intersection of points-to sets
  // of the operand and the instruction is non-empty.
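  // For example (illustrative): a kTuple instruction forwards its operands
  // into its output, so this predicate is true for each of its operands,
  // whereas a kAdd defines a fresh output buffer, making it false.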
  auto output_buffers =
      points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet();
  auto operand_buffers =
      points_to_analysis_.GetPointsToSet(instruction->operand(operand_no))
          .CreateFlattenedSet();
  for (const LogicalBuffer* output_buffer : output_buffers) {
    if (operand_buffers.count(output_buffer) > 0) {
      return true;
    }
  }
  return false;
}

Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                          const LogicalBuffer& buffer,
                                          bool mandatory, bool dfs) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout);

  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString().c_str(),
        ShapeUtil::HumanString(buffer.shape()).c_str());
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  const BufferLayoutConstraint* curr_constraint =
      GetBufferLayoutConstraint(buffer);
  if (curr_constraint != nullptr) {
    if (LayoutUtil::Equal(curr_constraint->layout(), layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_constraint->mandatory()) {
      return FailedPrecondition(
          "Buffer %s already has the layout constraint %s, cannot add "
          "incompatible constraint %s",
          buffer.ToString().c_str(),
          LayoutUtil::HumanString(curr_constraint->layout()).c_str(),
          LayoutUtil::HumanString(layout).c_str());
    }
  }

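  // At this point any pre-existing constraint is non-mandatory and differs
  // from the new layout, so it is simply overwritten below.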
  auto iter = buffer_constraints_.find(&buffer);
  bool overwrite = iter != buffer_constraints_.end();
  if (!overwrite) {
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer,
                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
               .first;
  } else {
    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  // Remove buffer from the set of unconstrained buffers.
  TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) ==
               static_cast<int>(!overwrite));
  unconstrained_buffer_ids_.erase(buffer.id());

  return Status::OK();
}

Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64 operand_no, bool mandatory,
                                           bool dfs) {
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %lld of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name().c_str(),
          curr_shape_layout->shape_layout().ToString().c_str(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
    }
  }

  // If any buffers in the operand occur in the output of the instruction,
  // return an error: such a constraint would change layouts beyond this
  // immediate use and is complicated to handle, so this case is not supported.
  if (OperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constrain layout of operand %lld of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name().c_str());
  }

  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
                                     instruction, operand_no, mandatory, dfs));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  return Status::OK();
}

Status LayoutConstraints::SetArrayOperandLayout(
    const Layout& layout, const HloInstruction* instruction, int64 operand_no,
    bool mandatory, bool dfs) {
  const HloInstruction* operand = instruction->operand(operand_no);
  TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
  Shape shape(operand->shape());
  *shape.mutable_layout() = layout;
  TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
  return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
}

Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
                                          bool dfs) {
  VLOG(3) << "SetResultLayout : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const ShapeLayout* curr_shape_layout = ResultLayout();
  if (curr_shape_layout != nullptr) {
    if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
      return FailedPrecondition(
          "Result of computation %s already has the layout constraint %s, "
          "cannot add incompatible constraint %s",
          computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
    }
    // New constraint matches existing constraint. Nothing to do.
    return Status::OK();
  }

  result_constraint_.reset(
      new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
  added_constraints_.push_back(result_constraint_.get());

  return Status::OK();
}

Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name().c_str(),
        ShapeUtil::HumanString(instruction->shape()).c_str(),
        ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
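  // For a tuple-shaped instruction such as (f32[2,3], f32[4]), this visits the
  // leaf shapes at indices {0} and {1} and constrains one logical buffer per
  // leaf (example shapes are illustrative).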
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        if (ShapeUtil::IsArray(subshape)) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}

const Layout* LayoutConstraints::BufferLayout(
    const LogicalBuffer& buffer) const {
  if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
    return &constraint->layout();
  }
  return nullptr;
}

const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
    const LogicalBuffer& buffer) const {
  auto it = buffer_constraints_.find(&buffer);
  return it == buffer_constraints_.end() ? nullptr : &it->second;
}

const ShapeLayout* LayoutConstraints::OperandLayout(
    const HloInstruction* instruction, int64 operand_no) const {
  if (const auto* constraint =
          GetOperandLayoutConstraint(instruction, operand_no)) {
    return &constraint->shape_layout();
  }
  return nullptr;
}

const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
    const HloInstruction* instruction, int64 operand_no) const {
  auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
  return it == operand_constraints_.end() ? nullptr : &it->second;
}

const ShapeLayout* LayoutConstraints::ResultLayout() const {
  return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
}

string LayoutConstraints::ToString() const {
  string output;
  tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
                                 computation_->name(), ":\n");
  for (auto* instruction : computation_->MakeInstructionPostOrder()) {
    tensorflow::strings::StrAppend(&output, "  ", instruction->ToShortString(),
                                   "\n");
    for (int64 i = 0; i < instruction->operand_count(); ++i) {
      if (OperandLayout(instruction, i) != nullptr) {
        tensorflow::strings::StrAppend(
            &output, "    operand (", i,
            "): ", OperandLayout(instruction, i)->ToString(), "\n");
      }
    }
    for (const LogicalBuffer* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (BufferLayout(*buffer) != nullptr) {
        tensorflow::strings::StrAppend(
            &output, "    ", buffer->ToString(), " : ",
            LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
      }
    }
  }

  if (ResultLayout() != nullptr) {
    tensorflow::strings::StrAppend(&output, "  => ", ResultLayout()->ToString(),
                                   "\n");
  }
  return output;
}

Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout& computation_layout,
    const ChannelLayoutConstraints* channel_constraints,
    HloComputation* computation, LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    Shape const* shape_with_layout = nullptr;
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      DCHECK(!LayoutUtil::IsPadded(instruction->shape()));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      // Parameter layouts must match the respective layout in
      // ComputationLayout.
      shape_with_layout =
          &computation_layout.parameter_layout(instruction->parameter_number())
               .shape();
    }
    if (shape_with_layout != nullptr) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(*shape_with_layout, instruction));
    }

    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      CHECK(channel_constraints)
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id();
      if (!channel_constraints->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape));
        Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel(
            send_buffer_shape, instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape));
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape = channel_constraints->LayoutShapeForChannel(
            recv_buffer_shape, instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    }
  }

  // Constrain layouts of instructions which call computations that have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (e.g., map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // The layouts of the input and output of a kWhile instruction must agree
      // and must match both the input and output of the body computation. In
      // addition, the input of the condition computation must match the kWhile
      // layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      const ComputationLayout& body_layout =
          FindOrDie(computation_layouts_, body);
      const ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // Return an error if earlier layout assignment of the embedded
      // computations has produced conflicting layouts.
      if (!ShapeUtil::Equal(body_layout.result_shape(),
                            body_layout.parameter_shape(0))) {
        return InternalError(
            "Parameter and result of body computation %s of while instruction "
            "%s have different layouts: %s vs %s",
            body->name().c_str(), instruction->name().c_str(),
            ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
            ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
      }
      if (!ShapeUtil::Equal(body->root_instruction()->shape(),
                            condition->parameter_instruction(0)->shape())) {
        return InternalError(
            "Parameter of condition computation %s of while instruction "
            "%s does not match body computation %s result: %s vs %s",
            condition->name().c_str(), instruction->name().c_str(),
            body->name().c_str(),
            ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
            ShapeUtil::HumanString(body_layout.result_shape()).c_str());
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // The layout of the true and false computations must match, and must
      // be the layout of the kConditional instruction.
      TF_RET_CHECK(instruction->operand_count() == 3);

      HloComputation* true_computation = instruction->true_computation();
      HloComputation* false_computation = instruction->false_computation();
      const HloInstruction* true_operand = instruction->operand(1);
      const HloInstruction* false_operand = instruction->operand(2);

      TF_RET_CHECK(true_computation->num_parameters() == 1);
      TF_RET_CHECK(false_computation->num_parameters() == 1);
      ComputationLayout& true_computation_layout =
          FindOrDie(computation_layouts_, true_computation);
      ComputationLayout& false_computation_layout =
          FindOrDie(computation_layouts_, false_computation);

      DCHECK(ShapeUtil::Compatible(true_operand->shape(),
                                   true_computation_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(
          false_operand->shape(), false_computation_layout.parameter_shape(0)));

      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          true_computation_layout.result_shape(), instruction));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          true_computation_layout.parameter_shape(0), instruction, 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          false_computation_layout.parameter_shape(0), instruction, 2,
          /*mandatory=*/true));
    } else if (instruction->opcode() == HloOpcode::kCustomCall) {
      if (!CustomCallRequiresMajorFirstLayout(instruction)) {
        continue;
      }
      // Add constraints for the kCustomCall instruction's operands and its
      // output. For now we only support major-first layouts for all inputs
      // and outputs.
      Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
          instruction->shape().element_type(),
          AsInt64Slice(instruction->shape().dimensions()));
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(result_shape, instruction));
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        const Shape& operand_shape = instruction->operand(i)->shape();
        // Opaque operands don't get a layout constraint.
        if (ShapeUtil::IsOpaque(operand_shape)) {
          continue;
        }

        Shape row_major_operand_shape =
            ShapeUtil::MakeShapeWithDescendingLayout(
                operand_shape.element_type(),
                AsInt64Slice(operand_shape.dimensions()));
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            row_major_operand_shape, instruction, i));
      }
    }
  }

  // Finally set the result layout to match ComputationLayout.
  return constraints->SetResultLayout(
      computation_layout.result_layout().shape());
}

namespace {

// The operands of a call must match the layouts of parameters in the
// ComputationLayout, and the call instruction itself must match the result
// layout in the ComputationLayout.
Status CheckCallLayout(HloInstruction* call,
                       const ComputationLayout& computation_layout) {
  HloComputation* computation = call->to_apply();
  TF_RET_CHECK(computation->num_parameters() == call->operand_count());
  for (int64 i = 0; i < computation->num_parameters(); ++i) {
    TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
        call->operand(i)->shape()));
  }
  TF_RET_CHECK(
      computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
  return Status::OK();
}

// Custom calls have fixed input and output layouts.
Status CheckCustomCallLayout(HloInstruction* custom_call) {
  for (const HloInstruction* operand : custom_call->operands()) {
    TF_RET_CHECK(
        ShapeUtil::IsOpaque(operand->shape()) ||
        LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
  }
  TF_RET_CHECK(
      ShapeUtil::IsOpaque(custom_call->shape()) ||
      LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
  return Status::OK();
}

// For a while instruction, all the following layouts must be the same:
//   (1) init operand
//   (2) condition computation parameter
//   (3) body computation parameter
//   (4) body computation result
//   (5) while instruction result
Status CheckWhileLayout(HloInstruction* while_inst,
                        const ComputationLayout& condition_computation_layout,
                        const ComputationLayout& body_computation_layout) {
  auto init_shape = while_inst->operand(0)->shape();
  TF_RET_CHECK(
      condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
          init_shape));
  TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
      init_shape));
  TF_RET_CHECK(
      body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
  TF_RET_CHECK(
      LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
  return Status::OK();
}

Status CheckConditionalLayout(
    HloInstruction* instruction,
    const ComputationLayout& true_computation_layout,
    const ComputationLayout& false_computation_layout) {
  HloComputation* true_computation = instruction->true_computation();
  HloComputation* false_computation = instruction->false_computation();
  const HloInstruction* true_operand = instruction->operand(1);
  const HloInstruction* false_operand = instruction->operand(2);

  TF_RET_CHECK(true_computation_layout.result_layout() ==
               false_computation_layout.result_layout());
  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
      instruction->shape()));
  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
      true_computation->root_instruction()->shape()));
  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
      instruction->shape()));
  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
      false_computation->root_instruction()->shape()));
  TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
      true_operand->shape()));
  TF_RET_CHECK(
      false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
          false_operand->shape()));
  return Status::OK();
}

// Fusion parameters must match the layouts of the fusion instruction's
// operands, and the root of the fusion expression must match the layout of the
// fusion instruction.
Status CheckFusionLayout(HloInstruction* fusion) {
  TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());

  TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
      fusion->shape(), fusion->fused_expression_root()->shape()));
  for (int64 i = 0; i < fusion->operand_count(); ++i) {
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
        fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
  }
  return Status::OK();
}

// The layout of a parameter must match the respective layout in the
// computation's ComputationLayout.
Status CheckParameterLayout(HloInstruction* parameter,
                            const ComputationLayout& computation_layout) {
  const ShapeLayout& parameter_layout =
      computation_layout.parameter_layout(parameter->parameter_number());
  if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
    return InternalError(
        "parameter instruction %s does not match layout of computation "
        "shape: %s",
        parameter->ToString().c_str(), parameter_layout.ToString().c_str());
  }
  return Status::OK();
}

// The layout of a constant instruction must match the layout of its literal.
Status CheckConstantLayout(HloInstruction* constant) {
  if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
                                        constant->shape())) {
    return InternalError(
        "constant instruction %s does not match the layout of its literal %s",
        constant->ToString().c_str(),
        ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
  }
  return Status::OK();
}

}  // namespace

Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name().c_str(),
                      tensorflow::str_util::Join(index, ",").c_str(),
                      buffer->ToString().c_str(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape)
                          .c_str(),
                      ShapeUtil::HumanStringWithLayout(buffer->shape())
                          .c_str());
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          if (CustomCallRequiresMajorFirstLayout(instruction)) {
            TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          }
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional:
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->true_computation()),
              FindOrDie(computation_layouts_,
                        instruction->false_computation())));
          break;
        default:
          break;
      }
    }
  }

  // Finally verify the result layout matches the layout of the entry
  // computation root.
  TF_RET_CHECK(ShapeUtil::Equal(
      module->entry_computation()->root_instruction()->shape(),
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout()
          .shape()));

  return Status::OK();
}

LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),
      channel_layout_constraints_(channel_constraints) {
  VLOG(1) << "entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
  // Layouts of all parameter instructions must be set.
  for (const ShapeLayout& parameter_layout :
       entry_computation_layout_->parameter_layouts()) {
    CHECK(parameter_layout.LayoutIsSet());
  }
  // If the result layout is not set, then choose the default.
  // TODO(b/29118294): Choose a better layout in this case.
  if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
    entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
  }
}

std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);

  CHECK(ShapeUtil::IsArray(instruction->shape()));
  CHECK(ShapeUtil::IsArray(operand->shape()));

  if (instruction->IsElementwiseOnOperand(operand_no) &&
      !ShapeUtil::IsScalar(operand->shape()) &&
      ShapeUtil::Rank(operand->shape()) ==
          ShapeUtil::Rank(instruction->shape())) {
    // Assign operands the same layout as the instruction, so that
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    //
    // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
    // from assigning the same layout to input and output.
    return MakeUnique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, check whether the
    // reshape is a bitcast when the operand uses the same layout; this may
    // avoid copy operations. For similar reasons, if the operand and output
    // have the same rank, try to match the operand's layout to the output.
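    // For example (illustrative): reshaping f32[2,3]{1,0} to f32[6]{0} is a
    // bitcast, because both read the six elements in the same row-major
    // order, so choosing layout {1,0} for the operand avoids a copy.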
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::Rank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
      return MakeUnique<Layout>(operand_shape.layout());
    }
    if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
      *operand_shape.mutable_layout() = output_layout;
      if (ShapeUtil::ReshapeIsBitcast(operand_shape,
                                      output_shape_with_layout)) {
        return MakeUnique<Layout>(output_layout);
      }
    }
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return MakeUnique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
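    // Worked example (illustrative): for a 2D transpose with dimensions {1,0}
    // and output layout {1,0} (minor-to-major), the most-minor output dim 1
    // comes from operand dim dimensions(1) = 0, and output dim 0 from operand
    // dim 1, giving operand layout {0,1}; the transpose then only relabels
    // dimensions without moving data.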
    int64 rank = ShapeUtil::Rank(instruction->shape());
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return MakeUnique<Layout>(operand_layout);
  }

  return nullptr;
}

std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(ShapeUtil::IsArray(user->shape()) &&
        ShapeUtil::IsArray(operand->shape()));

  if (user->IsElementwiseOnOperand(operand_no) &&
      !ShapeUtil::IsScalar(operand->shape()) &&
      ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
    // Assign users the same layout as the operand.
    return MakeUnique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, check whether the
    // reshape is a bitcast when the user uses the same layout; this may avoid
    // copy operations. For similar reasons, if the operand and output have the
    // same rank, try to match the output's layout to the operand.
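    // This is the mirror image of the operand-side case above: e.g., for a
    // reshape of f32[6]{0} to f32[2,3], choosing output layout {1,0} keeps
    // the element order identical, so the reshape is a bitcast (illustrative).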
    if (ShapeUtil::Rank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
      return MakeUnique<Layout>(output_shape.layout());
    }
    if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
      *output_shape.mutable_layout() = operand_layout;
      if (ShapeUtil::ReshapeIsBitcast(output_shape,
                                      operand_shape_with_layout)) {
        return MakeUnique<Layout>(operand_layout);
      }
    }
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return MakeUnique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
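    // The permutation is inverted relative to the operand-side case: here we
    // ask which user dimension each operand dimension maps to. E.g., for
    // dimensions {1,0} and operand layout {0,1}, the user layout works out to
    // {1,0}, again making the transpose a pure relabeling (illustrative).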
    int64 rank = ShapeUtil::Rank(user->shape());
    std::vector<int64> new_minor_to_major(rank);
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return MakeUnique<Layout>(user_layout);
  }

  return nullptr;
}

Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before unrelated constraints
  // are propagated, which makes it less likely that conflicting constraints
  // will be propagated to instructions. However, we should experiment with
  // other orders too.
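  // For example (illustrative): a layout fixed on a parameter is pushed
  // through an entire chain of elementwise users before an unrelated result
  // constraint runs, so the chain is less likely to see two conflicting
  // layouts arriving from different directions.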
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors.";
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    add_new_constraints_to_worklist();
  }
  return Status::OK();
}

namespace {

// Returns a vector containing all array-shaped uses (instruction and operand
// number) of the given logical buffer or its aliases.
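//
// For example (illustrative): if a buffer defined by %x is aliased by
// %bitcast = bitcast(%x), and %add reads %bitcast as operand 1, the result
// includes the pair (%add, 1) as well as any direct uses of %x.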
std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
    const LogicalBuffer& buffer,
    const TuplePointsToAnalysis& points_to_analysis) {
  CHECK(buffer.IsArray());
  std::vector<std::pair<const HloInstruction*, int64>> uses;
  for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
    if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) {
      continue;
    }
    // This alias must be the top-level (index == {}) of the instruction's
    // result because the instruction produces an array.
    CHECK(buffer_alias.index().empty());

    // Add all uses of the instruction's output.
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      for (int64 operand_no :
           user->OperandIndices(buffer_alias.instruction())) {
        uses.emplace_back(user, operand_no);
      }
    }
  }
  return uses;
}

}  // namespace

Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [this, &shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            if (constraints->BufferLayout(*buffer) == nullptr &&
                ShapeUtil::IsArray(buffer->shape())) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}

Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!ShapeUtil::IsArray(operand->shape()) ||
      !ShapeUtil::IsArray(user->shape())) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }
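  // At this point 'user' defines its top-level output buffer, so look that
  // buffer up in order to constrain its layout.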
  TF_ASSIGN_OR_RETURN(
      const LogicalBuffer* buffer,
      constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{}));

  if (constraints->BufferLayout(*buffer) == nullptr) {
    std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
        operand_constraint.shape_layout().layout(), user,
        operand_constraint.operand_no());
    if (layout != nullptr) {
      TF_RETURN_IF_ERROR(
          constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false));
    }
  }
  return Status::OK();
}

Status LayoutAssignment::PropagateBufferConstraint(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  // Only propagate array layouts.
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  if (!buffer.IsArray()) {
    return Status::OK();
  }

  // If this buffer is the result of an array-shaped op (as opposed to an array
  // element in a tuple) try to propagate the layout to its operands.
  if (buffer.IsTopLevel()) {
    const HloInstruction* instruction = buffer.instruction();
    // Propagate the def-constraint on an instruction to the use-constraints on
    // its operands (use-def propagation).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      if (constraints->OperandLayout(instruction, operand_no) == nullptr &&
          ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) {
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/true));
        }
      }
    }
  }
  return PropagateBufferConstraintToUses(buffer_constraint, constraints);
}

Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because that case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  return Status::OK();
}

Status LayoutAssignment::PropagateResultConstraint(
    const ResultLayoutConstraint& result_constraint,
    LayoutConstraints* constraints) {
  // Propagate the use constraint of the root instruction up to the logical
  // buffers which make up the result.
  return PropagateUseConstraintToDefs(
      result_constraint.shape_layout(),
      constraints->computation()->root_instruction(), constraints);
}

namespace {

// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands).
StatusOr<Layout> InferArrayLayout(
    const TuplePointsToAnalysis& points_to_analysis,
    HloInstruction* instruction, const ShapeIndex& index) {
  // This function should only be called for array shapes which don't yet have
  // layouts.
  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
  TF_RET_CHECK(ShapeUtil::IsArray(subshape));
  TF_RET_CHECK(!subshape.has_layout());

  // The instruction should not define the buffer at this index.
  TF_RET_CHECK(
      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
      << instruction->ToString();

  const auto& source_buffers =
      points_to_analysis.GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    if (!source_buffer->shape().has_layout()) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString().c_str());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer->shape().layout();
    } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
                                  *first_buffer_layout)) {
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          tensorflow::str_util::Join(index, ",").c_str(),
          instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
          source_buffer->ToString().c_str());
    }
  }

  return *first_buffer_layout;
}

// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
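    // Post order visits an instruction only after its operands, so by the
    // time a GTE is handled below, its operand already has a layout set.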
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else {
      // Other instructions don't have layouts inside of fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}

}  // namespace

Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!ShapeUtil::IsArray(buffer->shape())) {
        continue;
      }

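      // Every array buffer defined by this instruction must have a constraint
      // by now; RunOnComputation gives default layouts to any buffers left
      // unconstrained before calling AssignLayouts.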
      TF_RET_CHECK(buffer->instruction() == instruction);
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute an extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }

  return Status::OK();
}

Status LayoutAssignment::RunOnComputation(
    const ComputationLayout& computation_layout,
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  DCHECK(computation_layout.LayoutIsSet());
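  // Record this computation's layout; callees are laid out before their
  // callers (see Run below), so callers can refer to it when adding their own
  // constraints.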
  InsertOrDie(&computation_layouts_, computation, computation_layout);
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";
  VLOG(2) << "  ComputationLayout = " << computation_layout.ToString();

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(points_to_analysis, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagate layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis.GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }

  // All logical buffers should have constraints at this point. All that
  // remains is to assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      channel_constraints->ConstrainChannel(
          instruction->channel_id(), instruction->operand(0)->shape().layout());
    } else if (instruction->opcode() == HloOpcode::kRecvDone) {
      channel_constraints->ConstrainChannel(instruction->channel_id(),
                                            instruction->shape().layout());
    }
  }
  return Status::OK();
}

StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  if (VLOG_IS_ON(10)) {
    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                "before layout assignment",
                                module->config().debug_options());
  }

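  // Points-to analysis maps each position in an instruction's output to the
  // logical buffers that may occupy it; both constraint propagation and layout
  // inference above depend on it.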
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));

  // Assign layouts to computations in an order such that a callee computation
  // is handled before its caller computation. This ensures that the layout of
  // all callers of a computation will agree.
  std::list<HloComputation*> computation_post_order =
      module->MakeComputationPostOrder();
  for (auto* computation : computation_post_order) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    // Clear existing layouts of the instructions.  All layouts must be assigned
    // by the LayoutAssignment pass, except for those on infeeds, parameters,
    // and the computation result. The latter two are specified in
    // computation_layout, so we only need to keep the existing layouts for
    // infeeds.  Clearing the layouts here avoids hiding potential bugs in the
    // layout assignment pass that may accidentally use the existing layout.
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kInfeed) {
        LayoutUtil::ClearLayout(instruction->mutable_shape());
      }
    }
    if (computation == module->entry_computation()) {
      TF_RETURN_IF_ERROR(RunOnComputation(
          *entry_computation_layout_, *points_to_analysis,
          module->entry_computation(), channel_layout_constraints_));
    } else {
      ComputationLayout computation_layout(computation->ComputeProgramShape());
      // Setting all embedded computations to the default layout is potentially
      // suboptimal.
      computation_layout.SetToDefaultLayout();
      TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
                                          *points_to_analysis, computation,
                                          channel_layout_constraints_));
    }
  }

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  VLOG(3) << "After layout assignment:";
  XLA_VLOG_LINES(3, module->ToString());
  if (VLOG_IS_ON(10)) {
    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                "after layout assignment",
                                module->config().debug_options());
  }

  // All layouts are reset then reassigned by this pass.
  return true;
}

}  // namespace xla